diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 0f66275c..651c9cbc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,6 +1,7 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" @@ -405,6 +406,31 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Upscale **/ + +struct ggml_webgpu_upscale_pipeline_key { + ggml_type input_type; + ggml_type output_type; + uint32_t base_mode; + bool antialias; + + bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode && + antialias == other.antialias; + } +}; + +struct ggml_webgpu_upscale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + ggml_webgpu_hash_combine(seed, key.base_mode); + ggml_webgpu_hash_combine(seed, key.antialias); + return seed; + } +}; + /** Concat **/ struct ggml_webgpu_concat_pipeline_key { @@ -1049,6 +1075,8 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_rms_norm_mul_pipeline_key_hash> rms_norm_mul_pipelines; + std::unordered_map + upscale_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -2947,6 +2975,72 @@ class ggml_webgpu_shader_lib { return im2col_pipelines[key]; } + webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0); + const uint32_t base_mode = mode_flags & 0xFFu; + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u; + + ggml_webgpu_upscale_pipeline_key key = {}; + key.input_type = context.src0->type; + key.output_type = context.dst->type; + key.base_mode = base_mode; + key.antialias = antialias; + + auto it = upscale_pipelines.find(key); + if (it != upscale_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "upscale"; + + if (key.input_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } else { + variant += "_src_f32"; + } + + if (key.output_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } else { + variant += "_dst_f32"; + } + + switch (base_mode) { + case GGML_SCALE_MODE_NEAREST: + defines.push_back("NEAREST"); + variant += "_nearest"; + break; + case GGML_SCALE_MODE_BILINEAR: + defines.push_back("BILINEAR"); + variant += "_bilinear"; + break; + case GGML_SCALE_MODE_BICUBIC: + defines.push_back("BICUBIC"); + variant += "_bicubic"; + break; + default: + GGML_ABORT("Unsupported upscale mode"); + } + + if (antialias) { + defines.push_back("ANTIALIAS"); + variant += "_aa"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_upscale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + upscale_pipelines[key] = pipeline; + return upscale_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f102c7a8..cab0aead 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2824,6 +2824,49 @@ static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, return true; } +static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0); + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + mode_flags }; + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + // Returns the encoded command, or std::nullopt if the operation is a no-op static std::optional ggml_webgpu_encode(webgpu_context ctx, ggml_cgraph * cgraph, @@ -2931,6 +2974,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_conv_2d(ctx, src0, src1, node); case GGML_OP_IM2COL: return ggml_webgpu_im2col(ctx, src0, src1, node); + case GGML_OP_UPSCALE: + return ggml_webgpu_upscale(ctx, src0, node); default: return std::nullopt; } @@ -4163,6 +4208,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUM_ROWS: supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0); break; + case GGML_OP_UPSCALE: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl new file mode 100644 index 00000000..e9ef8822 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl @@ -0,0 +1,240 @@ +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif + +#ifdef SRC_F16 +#define SRC_TYPE f16 +#else +#define SRC_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + src_w: u32, + src_h: u32, + src_z: u32, + src_n: u32, + + dst_w: u32, + dst_h: u32, + dst_z: u32, + dst_n: u32, + + mode_flags: u32, +}; + +@group(0) @binding(2) +var params: Params; + +const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u; + +fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 { + let cx = u32(clamp(x, 0, i32(params.src_w) - 1)); + let cy = u32(clamp(y, 0, i32(params.src_h) - 1)); + let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3; + return f32(input[i]); +} + +fn cubic_weight(t: f32, a: f32) -> f32 { + let at = abs(t); + if (at <= 1.0) { + return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0; + } else if (at <= 2.0) { + return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a; + } else { + return 0.0; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y; + let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n; + + if (i_out >= total) { + return; + } + + // decode (x, y, z, n) + var i = i_out; + let x_dst = i % params.dst_w; + i = i / params.dst_w; + let y_dst = i % params.dst_h; + i = i / params.dst_h; + let z_dst = i % params.dst_z; + let n_dst = i / params.dst_z; + + // scale factors + var sf0 = f32(params.dst_w) / f32(params.src_w); + var sf1 = f32(params.dst_h) / f32(params.src_h); + var sf2 = f32(params.dst_z) / f32(params.src_z); + var sf3 = f32(params.dst_n) / f32(params.src_n); + + let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0; + + // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners + var pixel_offset = 0.5; + if (align_corners) { + pixel_offset = 0.0; + if (params.dst_w > 1 && params.src_w > 1) { + sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1); + } + if (params.dst_h > 1 && params.src_h > 1) { + sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1); + } + } + + let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2))); + let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3))); + + var result = 0.0; + +#if defined(NEAREST) + + let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0))); + let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1))); + + result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src); + +#elif defined(BILINEAR) + +#if defined(ANTIALIAS) + + // Antialiased bilinear: triangle filter over a variable support region. + let support0 = max(1.0f / sf0, 1.0f); + let support1 = max(1.0f / sf1, 1.0f); + let invscale0 = 1.0 / support0; + let invscale1 = 1.0 / support1; + + let fx = (f32(x_dst) + pixel_offset) / sf0; + let fy = (f32(y_dst) + pixel_offset) / sf1; + + let x_min = max(i32(fx - support0 + pixel_offset), 0); + let y_min = max(i32(fy - support1 + pixel_offset), 0); + let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w)); + let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h)); + + var weighted_sum = 0.0; + var total_weight = 0.0; + + for (var x = x_min; x < x_max; x += 1) { + let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0); + for (var y = y_min; y < y_max; y += 1) { + let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0); + let w = wx * wy; + if (w > 0.0) { + weighted_sum += get_clamped_input(x, y, z_src, n_src) * w; + total_weight += w; + } + } + } + + if (total_weight > 0.0) { + result = weighted_sum / total_weight; + } + +#else + + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = clamp(fx - f32(x0), 0.0, 1.0); + let dy = clamp(fy - f32(y0), 0.0, 1.0); + let a = get_clamped_input(x0, y0, z_src, n_src); + let b = get_clamped_input(x0 + 1, y0, z_src, n_src); + let c = get_clamped_input(x0, y0 + 1, z_src, n_src); + let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + + let wa = (1.0 - dx) * (1.0 - dy); + let wb = dx * (1.0 - dy); + let wc = (1.0 - dx) * dy; + let wd = dx * dy; + + result = a * wa + b * wb + c * wc + d * wd; + +#endif + +#elif defined(BICUBIC) + + // bicubic convolution with alpha = -0.75 (PyTorch default) + let alpha = -0.75; + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = fx - f32(x0); + let dy = fy - f32(y0); + + // horizontal weights for offsets -1, 0, 1, 2 + let wx0 = cubic_weight(dx + 1.0, alpha); + let wx1 = cubic_weight(dx, alpha); + let wx2 = cubic_weight(1.0 - dx, alpha); + let wx3 = cubic_weight(2.0 - dx, alpha); + + // vertical weights for offsets -1, 0, 1, 2 + let wy0 = cubic_weight(dy + 1.0, alpha); + let wy1 = cubic_weight(dy, alpha); + let wy2 = cubic_weight(1.0 - dy, alpha); + let wy3 = cubic_weight(2.0 - dy, alpha); + + // intermediate horizontal interpolation for 4x4 grid of pixels + // x0-1, x0, x0+1, x0+2, y0-1 + let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src); + let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src); + let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src); + let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src); + let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0 + let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src); + let q1 = get_clamped_input(x0, y0, z_src, n_src); + let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src); + let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src); + let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+1 + let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src); + let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src); + let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src); + let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+2 + let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src); + let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src); + let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src); + let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src); + let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3; + + // final vertical interpolation + result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3; + +#endif + + let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3; + output[dst_idx] = DST_TYPE(result); +}