diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index e492c212..16ebc32c 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -98,6 +98,29 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
     uint32_t tokens_per_wg;
 };
 
+struct ggml_webgpu_ssm_scan_pipeline_key {
+    int type;
+    int d_state;
+
+    bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
+        return type == other.type && d_state == other.d_state;
+    }
+};
+
+struct ggml_webgpu_ssm_scan_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.d_state);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_ssm_scan_shader_decisions {
+    uint32_t wg_size;
+    uint32_t tokens_per_tile;
+};
+
 /** Argsort **/
 
 struct ggml_webgpu_argsort_shader_lib_context {
@@ -921,6 +944,8 @@ class ggml_webgpu_shader_lib {
         solve_tri_pipelines;  // type
     std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
         ssm_conv_pipelines;  // type/vectorized
+    std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
+        ssm_scan_pipelines;  // type/d_state
     std::unordered_map
@@ -1433,6 +1458,53 @@ class ggml_webgpu_shader_lib {
         return ssm_conv_pipelines[key];
     }
 
+    webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_ssm_scan_pipeline_key key = {};
+        key.type    = context.dst->type;
+        key.d_state = (int) context.src0->ne[0];
+
+        auto it = ssm_scan_pipelines.find(key);
+        if (it != ssm_scan_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "ssm_scan";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                variant += "_f32";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for ssm_scan shader");
+        }
+
+        const uint32_t wg_size = (uint32_t) key.d_state;
+
+        constexpr uint32_t tokens_per_tile = 4u;
+
+        defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
+        defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");
+
+        if (context.supports_subgroups) {
+            defines.push_back("USE_SUBGROUP_REDUCTION");
+            variant += "_sg_reduce";
+        } else {
+            variant += "_wg_reduce";
+        }
+
+        variant += "_d" + std::to_string(key.d_state);
+
+        auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
+        auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
+        decisions->wg_size         = wg_size;
+        decisions->tokens_per_tile = tokens_per_tile;
+        webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context           = decisions;
+        ssm_scan_pipelines[key]    = pipeline;
+        return ssm_scan_pipelines[key];
+    }
+
     webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_gated_delta_net_pipeline_key key = {};
         key.type = context.dst->type;
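A note on the caching scheme above: pipelines are memoized per (dst type, d_state) pair because WG_SIZE is baked into the WGSL at preprocess time, so every distinct state size compiles its own specialization. ggml_webgpu_hash_combine is defined elsewhere in the backend; below is a minimal self-contained sketch of the pattern, assuming the combiner is the conventional boost-style mixer (that body is an assumption, not the backend's verbatim code):

    #include <cstddef>
    #include <functional>
    #include <unordered_map>

    // Conventional boost-style mixer; assumed equivalent of ggml_webgpu_hash_combine.
    template <typename T> static void hash_combine(std::size_t & seed, const T & v) {
        seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }

    struct scan_key {  // stand-in for ggml_webgpu_ssm_scan_pipeline_key
        int type;
        int d_state;
        bool operator==(const scan_key & o) const { return type == o.type && d_state == o.d_state; }
    };

    struct scan_key_hash {  // stand-in for ggml_webgpu_ssm_scan_pipeline_key_hash
        std::size_t operator()(const scan_key & k) const {
            std::size_t seed = 0;
            hash_combine(seed, k.type);
            hash_combine(seed, k.d_state);
            return seed;
        }
    };

    // One cached pipeline per (dst type, d_state) pair, mirroring ssm_scan_pipelines:
    // std::unordered_map<scan_key, webgpu_pipeline, scan_key_hash> cache;

The operator== overload is what unordered_map actually uses to resolve bucket collisions; the hash only narrows the search.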
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 7ed6fdd1..bcec20c1 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -1115,6 +1115,80 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }
 
+static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
+                                              ggml_tensor *    src0,
+                                              ggml_tensor *    src1,
+                                              ggml_tensor *    src2,
+                                              ggml_tensor *    src3,
+                                              ggml_tensor *    src4,
+                                              ggml_tensor *    src5,
+                                              ggml_tensor *    src6,
+                                              ggml_tensor *    dst) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0               = src0;
+    shader_lib_ctx.dst                = dst;
+    shader_lib_ctx.max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+
+        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
+        (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
+
+        (uint32_t) src3->ne[0],
+        (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)),
+
+        (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
+        (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
+        (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
+
+        (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)),
+        (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)),
+        (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)),
+
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+        (uint32_t) src4->ne[1],
+        (uint32_t) src1->ne[2],
+        (uint32_t) src1->ne[3],
+        (uint32_t) ggml_nelements(src1),
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
+    };
+
+    const uint32_t total_wg       = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
+    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+    uint32_t       wg_x;
+    uint32_t       wg_y;
+    compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
+
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
+}
+
 static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
                                                      ggml_tensor *    src0,
                                                      ggml_tensor *    src1,
                                                      ggml_tensor *    src2,
                                                      ggml_tensor *    src3,
                                                      ggml_tensor *    src4,
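ggml_webgpu_ssm_scan launches one workgroup per (d_inner, n_head, n_seqs) triple and relies on compute_2d_workgroups, defined elsewhere in this file, to split that count across two dispatch dimensions, since a single dimension is capped at maxComputeWorkgroupsPerDimension. A rough sketch of such a split, under the assumption that the helper rounds up rather than factorizing exactly (in which case any overshoot has to be benign, e.g. guarded in the kernel):

    #include <cstdint>

    // Hypothetical stand-in for compute_2d_workgroups: split `total` workgroups
    // into a (wg_x, wg_y) grid with both sides <= max_per_dim. The shader
    // re-linearizes the grid as wg_linear = wg_id.y * num_wg.x + wg_id.x.
    static void split_workgroups_2d(uint32_t total, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
        if (total <= max_per_dim) {
            wg_x = total;
            wg_y = 1;
            return;
        }
        wg_y = (total + max_per_dim - 1) / max_per_dim;  // ceil-divide into rows
        wg_x = (total + wg_y - 1) / wg_y;                // columns covering total
        // wg_x * wg_y >= total; any overshoot dispatches extra groups, which
        // must be handled (bounds check) or avoided (exact factorization).
    }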
@@ -2764,6 +2838,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
             return ggml_webgpu_solve_tri(ctx, src0, src1, node);
         case GGML_OP_SSM_CONV:
             return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
+        case GGML_OP_SSM_SCAN:
+            return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6],
+                                        node);
         case GGML_OP_GATED_DELTA_NET:
             return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4],
                                                node->src[5], node);
         case GGML_OP_PAD:
@@ -2822,7 +2899,10 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context &
 }
 #endif
 
+// Don't bother checking SET_ROWS index overflow for now: in practice, the WebGPU backend doesn't need to support
+// models that would require it.
 static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) {
+#ifdef GGML_WEBGPU_CHECK_SET_ROWS
     wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
                                ctx->set_rows_host_error_buf.GetSize());
@@ -2835,6 +2915,10 @@ static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t &
         GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
     }
     ctx->set_rows_host_error_buf.Unmap();
+#else
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(num_inflight_batches);
+#endif
 }
 
 static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -2920,8 +3004,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
             ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches);
         }
 
-    ggml_backend_webgpu_wait_queue(ctx->global_ctx);
-
     WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
     return GGML_STATUS_SUCCESS;
 }
@@ -3941,6 +4023,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SSM_CONV:
             supports_op = op->type == GGML_TYPE_F32;
             break;
+        case GGML_OP_SSM_SCAN:
+            supports_op = op->type == GGML_TYPE_F32 &&
+                          src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+            break;
         case GGML_OP_GATED_DELTA_NET:
             {
                 const uint32_t s_v = (uint32_t) src2->ne[0];
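The new shader below runs the selective-scan recurrence with one thread per state element (hence WG_SIZE = d_state, and the supports_op cap above) and stages TOKENS_PER_TILE tokens at a time through workgroup memory. As a reference for what each (inner channel, head, sequence) slot computes, here is a scalar model with illustrative names (not the backend's API), using the same overflow-guarded softplus as the shader:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Overflow-guarded softplus, mirroring the shader's
    // select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0).
    static float softplus(float v) {
        return v > 20.0f ? v : std::log1p(std::exp(v));
    }

    // One token step for one (inner channel, head) slot. s is the d_state
    // running state; A may be broadcast (a_ne0), hence the modulo.
    static float ssm_scan_step(std::vector<float> & s, const std::vector<float> & A,
                               const std::vector<float> & B, const std::vector<float> & C,
                               float x, float dt_raw) {
        const float dt   = softplus(dt_raw);
        const float x_dt = x * dt;
        float       y    = 0.0f;
        for (std::size_t i = 0; i < s.size(); i++) {
            s[i] = s[i] * std::exp(dt * A[i % A.size()]) + B[i] * x_dt;  // state update
            y   += s[i] * C[i];                                          // output projection
        }
        return y;  // the workgroup reduction computes this sum in parallel
    }

The dot product y is what the workgroup computes cooperatively in the reduction paths; the state element itself is carried in a register as s_prev and written back after the token loop.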
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl
new file mode 100644
index 00000000..64324738
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl
@@ -0,0 +1,168 @@
+#ifdef USE_SUBGROUP_REDUCTION
+enable subgroups;
+#endif
+
+struct Params {
+    offset_s: u32,
+    offset_x: u32,
+    offset_dt: u32,
+    offset_A: u32,
+    offset_B: u32,
+    offset_C: u32,
+    offset_ids: u32,
+    offset_dst: u32,
+
+    stride_s1: u32,
+    stride_s2: u32,
+    stride_s3: u32,
+
+    stride_x1: u32,
+    stride_x2: u32,
+    stride_x3: u32,
+
+    stride_dt1: u32,
+    stride_dt2: u32,
+
+    a_ne0: u32,
+    stride_A1: u32,
+
+    stride_B1: u32,
+    stride_B2: u32,
+    stride_B3: u32,
+
+    stride_C1: u32,
+    stride_C2: u32,
+    stride_C3: u32,
+
+    d_state: u32,
+    d_inner: u32,
+    n_head: u32,
+    n_group: u32,
+    n_seq_tokens: u32,
+    n_seqs: u32,
+
+    y_elems: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
+@group(0) @binding(1) var<storage, read_write> x: array<f32>;
+@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
+@group(0) @binding(3) var<storage, read_write> A: array<f32>;
+@group(0) @binding(4) var<storage, read_write> B: array<f32>;
+@group(0) @binding(5) var<storage, read_write> C: array<f32>;
+@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
+@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(8) var<uniform> params: Params;
+
+var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
+var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
+var<workgroup> shared_reduce: array<f32, TOKENS_PER_TILE * WG_SIZE>;
+
+fn reduce_base(token_in_tile: u32) -> u32 {
+    return token_in_tile * WG_SIZE;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(num_workgroups) num_wg: vec3<u32>
+#ifdef USE_SUBGROUP_REDUCTION
+    , @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
+    @builtin(num_subgroups) num_subgroups: u32
+#endif
+) {
+    let tid       = local_id.x;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+
+    let i1       = wg_linear % params.d_inner;
+    let head_seq = wg_linear / params.d_inner;
+    let ir       = head_seq % params.n_head;
+    let i3       = head_seq / params.n_head;
+
+    let state_slot = u32(ids[params.offset_ids + i3]);
+    let g          = ir / (params.n_head / params.n_group);
+
+    let s_idx  = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3;
+    var s_prev = s_in[s_idx];
+
+    let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1];
+
+    for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) {
+        if (tid < TOKENS_PER_TILE) {
+            let token = token_base + tid;
+            if (token < params.n_seq_tokens) {
+                let x_idx  = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3;
+                let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2;
+                let dt0    = dt[dt_idx];
+                let dtsp   = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
+                shared_dtsp[tid] = dtsp;
+                shared_x_dt[tid] = x[x_idx] * dtsp;
+            }
+        }
+
+        workgroupBarrier();
+
+        for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) {
+            let token = token_base + token_in_tile;
+            if (token >= params.n_seq_tokens) {
+                break;
+            }
+
+            let x_dt       = shared_x_dt[token_in_tile];
+            let dA         = exp(shared_dtsp[token_in_tile] * A0);
+            let reduce_idx = reduce_base(token_in_tile) + tid;
+
+            let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
+            let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
+            let s     = s_prev * dA + B[b_idx] * x_dt;
+            s_prev    = s;
+
+#ifdef USE_SUBGROUP_REDUCTION
+            let subgroup_partial = subgroupAdd(s * C[c_idx]);
+            if (subgroup_invocation_id == 0u) {
+                shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
+            }
+#else
+            shared_reduce[reduce_idx] = s * C[c_idx];
+#endif
+
+            workgroupBarrier();
+
+#ifdef USE_SUBGROUP_REDUCTION
+            if (tid == 0u) {
+                var sum = 0.0;
+                for (var sg = 0u; sg < num_subgroups; sg++) {
+                    sum += shared_reduce[reduce_base(token_in_tile) + sg];
+                }
+                let y_idx = params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
+                            i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
+                dst[y_idx] = sum;
+            }
+#else
+            for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) {
+                if (tid < stride) {
+                    shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride];
+                }
+                workgroupBarrier();
+            }
+
+            if (tid == 0u) {
+                let y_idx = params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) +
+                            i3 * (params.n_seq_tokens * params.n_head * params.d_inner);
+                dst[y_idx] = shared_reduce[reduce_base(token_in_tile)];
+            }
+#endif
+
+            workgroupBarrier();
+        }
+    }
+
+    let state_idx = params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) +
+                    i3 * (params.d_state * params.d_inner * params.n_head);
+    dst[state_idx] = s_prev;
+}
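The shader emits two reduction variants from one source: with USE_SUBGROUP_REDUCTION, each subgroup folds its lanes with subgroupAdd and thread 0 adds the num_subgroups partials; otherwise a shared-memory tree reduction runs in log2(WG_SIZE) steps. The tree loop only sums every lane when WG_SIZE (equal to d_state) is a power of two, which holds for the common Mamba-family state sizes (typically 16 or 128). A host-side model of that fallback loop, assuming a power-of-two length:

    #include <cstdint>
    #include <vector>

    // Host-side model of the #else reduction path: classic halving tree.
    // Correct only when v.size() is a power of two, matching WG_SIZE = d_state.
    static float tree_reduce(std::vector<float> v) {
        for (uint32_t stride = (uint32_t) v.size() / 2; stride > 0; stride >>= 1) {
            for (uint32_t tid = 0; tid < stride; tid++) {
                v[tid] += v[tid + stride];  // each "lane" folds its upper half
            }
            // workgroupBarrier() keeps shader iterations in lockstep here
        }
        return v[0];
    }

Note the shader keeps workgroupBarrier() outside the tid < stride branch: WGSL requires barriers to execute in uniform control flow, so every thread must reach each one.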