ggml-webgpu: add the upscale shader (llama/22419)

* shader(upscale): add the upscale shader with nearest, bilinear and bicubic implementations

* shader(upscale): use macro
This commit is contained in:
Chen Yuan 2026-05-01 01:22:18 -04:00 committed by Georgi Gerganov
parent b34a9f3d83
commit ccd04522f9
No known key found for this signature in database
GPG Key ID: BF970631944C16B7
3 changed files with 383 additions and 0 deletions

View File

@ -1,6 +1,7 @@
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
#define GGML_WEBGPU_SHADER_LIB_HPP
#include "ggml-impl.h"
#include "ggml-wgsl-shaders.hpp"
#include "ggml.h"
#include "pre_wgsl.hpp"
@ -405,6 +406,31 @@ struct ggml_webgpu_scale_pipeline_key_hash {
}
};
/** Upscale **/
struct ggml_webgpu_upscale_pipeline_key {
    ggml_type input_type;   // element type of the source tensor
    ggml_type output_type;  // element type of the destination tensor
    uint32_t  base_mode;    // GGML_SCALE_MODE_* value (low byte of the op's mode flags)
    bool      antialias;    // true when the antialias flag was set on the op

    // Keys are equal only when every field matches.
    bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const {
        if (input_type != other.input_type || output_type != other.output_type) {
            return false;
        }
        return base_mode == other.base_mode && antialias == other.antialias;
    }
};
// Hash functor for ggml_webgpu_upscale_pipeline_key so it can be used as an
// unordered_map key.
struct ggml_webgpu_upscale_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const {
        // Fold every field of the key into one accumulator.
        size_t h = 0;
        ggml_webgpu_hash_combine(h, key.input_type);
        ggml_webgpu_hash_combine(h, key.output_type);
        ggml_webgpu_hash_combine(h, key.base_mode);
        ggml_webgpu_hash_combine(h, key.antialias);
        return h;
    }
};
/** Concat **/
struct ggml_webgpu_concat_pipeline_key {
@ -1049,6 +1075,8 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline,
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
rms_norm_mul_pipelines;
std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash>
upscale_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -2947,6 +2975,72 @@ class ggml_webgpu_shader_lib {
return im2col_pipelines[key];
}
// Build (or fetch from cache) the compute pipeline for GGML_OP_UPSCALE.
// The shader variant is selected by source/destination element type, the base
// scale mode (nearest / bilinear / bicubic) and the antialias flag. The
// align-corners flag is read by the shader from the uniform params at run
// time, so it does not participate in the pipeline key.
webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0);
    const uint32_t base_mode  = mode_flags & 0xFFu;  // low byte carries GGML_SCALE_MODE_*
    const bool     antialias  = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u;

    ggml_webgpu_upscale_pipeline_key key = {};
    key.input_type  = context.src0->type;
    key.output_type = context.dst->type;
    key.base_mode   = base_mode;
    key.antialias   = antialias;

    auto it = upscale_pipelines.find(key);
    if (it != upscale_pipelines.end()) {
        return it->second;
    }

    // Cache miss: assemble the preprocessor defines and a human-readable
    // variant name that encodes the same choices.
    std::vector<std::string> defines;
    std::string              variant = "upscale";
    if (key.input_type == GGML_TYPE_F16) {
        defines.push_back("SRC_F16");
        variant += "_src_f16";
    } else {
        variant += "_src_f32";
    }
    if (key.output_type == GGML_TYPE_F16) {
        defines.push_back("DST_F16");
        variant += "_dst_f16";
    } else {
        variant += "_dst_f32";
    }
    switch (base_mode) {
        case GGML_SCALE_MODE_NEAREST:
            defines.push_back("NEAREST");
            variant += "_nearest";
            break;
        case GGML_SCALE_MODE_BILINEAR:
            defines.push_back("BILINEAR");
            variant += "_bilinear";
            break;
        case GGML_SCALE_MODE_BICUBIC:
            defines.push_back("BICUBIC");
            variant += "_bicubic";
            break;
        default:
            GGML_ABORT("Unsupported upscale mode");
    }
    if (antialias) {
        defines.push_back("ANTIALIAS");
        variant += "_aa";
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_upscale, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    // Insert once and return the cached entry; the original operator[] insert
    // followed by a second map lookup on return performed two extra hashes.
    return upscale_pipelines.emplace(key, pipeline).first->second;
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,

View File

@ -2824,6 +2824,49 @@ static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph,
return true;
}
// Encode a GGML_OP_UPSCALE node: resample src into dst using the scale mode
// stored in the op params (nearest / bilinear / bicubic plus flag bits).
static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) {
    const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0);

    const size_t src_ts = ggml_type_size(src->type);
    const size_t dst_ts = ggml_type_size(dst->type);

    // Uniform layout expected by the shader's Params struct:
    // offsets, source strides, dest strides, source extents, dest extents, mode.
    std::vector<uint32_t> params;
    params.reserve(19);
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_ts));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_ts));
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) (src->nb[d] / src_ts));  // source element strides
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) (dst->nb[d] / dst_ts));  // destination element strides
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) src->ne[d]);  // source extents
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) dst->ne[d]);  // destination extents
    }
    params.push_back(mode_flags);

    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
                                                  ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) };

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0        = src;
    shader_lib_ctx.dst         = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;

    webgpu_pipeline pipeline  = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx);
    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    // One thread per destination element, flattened into a 2D dispatch so the
    // per-dimension workgroup limit is respected.
    const uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
    const uint32_t wg_x     = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
    const uint32_t wg_y     = CEIL_DIV(total_wg, wg_x);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
// Returns the encoded command, or std::nullopt if the operation is a no-op
static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
ggml_cgraph * cgraph,
@ -2931,6 +2974,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
case GGML_OP_IM2COL:
return ggml_webgpu_im2col(ctx, src0, src1, node);
case GGML_OP_UPSCALE:
return ggml_webgpu_upscale(ctx, src0, node);
default:
return std::nullopt;
}
@ -4163,6 +4208,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SUM_ROWS:
supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
break;
case GGML_OP_UPSCALE:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
default:
break;
}

View File

@ -0,0 +1,240 @@
#if defined(SRC_F16) || defined(DST_F16)
enable f16;
#endif
#ifdef SRC_F16
#define SRC_TYPE f16
#else
#define SRC_TYPE f32
#endif
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif
@group(0) @binding(0)
var<storage, read_write> input: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> output: array<DST_TYPE>;
// Uniform parameters for the upscale kernel; must match the uint32 vector
// assembled on the host side (offsets, strides, extents, mode flags).
struct Params {
    offset_i: u32,  // element offset of the source tensor inside its buffer
    offset_o: u32,  // element offset of the destination tensor inside its buffer
    // element strides
    si0: u32, si1: u32, si2: u32, si3: u32,  // source strides, per dimension
    so0: u32, so1: u32, so2: u32, so3: u32,  // destination strides, per dimension
    src_w: u32,  // source extent, dim 0
    src_h: u32,  // source extent, dim 1
    src_z: u32,  // source extent, dim 2
    src_n: u32,  // source extent, dim 3
    dst_w: u32,  // destination extent, dim 0
    dst_h: u32,  // destination extent, dim 1
    dst_z: u32,  // destination extent, dim 2
    dst_n: u32,  // destination extent, dim 3
    mode_flags: u32,  // GGML_SCALE_MODE_* in the low byte plus flag bits (e.g. align-corners)
};
@group(0) @binding(2)
var<uniform> params: Params;
const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u;
// Read one source element at (x, y, z, n), clamping x and y to the valid
// source extent (clamp-to-edge addressing); returned as f32.
fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 {
    let xi = u32(clamp(x, 0, i32(params.src_w) - 1));
    let yi = u32(clamp(y, 0, i32(params.src_h) - 1));
    let idx = params.offset_i + xi * params.si0 + yi * params.si1 + z * params.si2 + n * params.si3;
    return f32(input[idx]);
}
// Cubic convolution kernel weight for sample distance t and sharpness
// parameter a; zero outside the |t| <= 2 support.
fn cubic_weight(t: f32, a: f32) -> f32 {
    let d = abs(t);
    if (d > 2.0) {
        return 0.0;
    }
    let d2 = d * d;
    let d3 = d2 * d;
    if (d <= 1.0) {
        return (a + 2.0) * d3 - (a + 3.0) * d2 + 1.0;
    }
    return a * d3 - 5.0 * a * d2 + 8.0 * a * d - 4.0 * a;
}
// One invocation produces one destination element: decode the flat output
// index into (x, y, z, n), then sample the source with the interpolation mode
// chosen at pipeline-build time via preprocessor defines
// (NEAREST / BILINEAR / BICUBIC, optionally ANTIALIAS).
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {
    // The host flattens the dispatch over two workgroup dimensions; fold
    // gid.y back into a linear destination index.
    let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y;
    let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n;
    if (i_out >= total) {
        return;
    }
    // decode (x, y, z, n)
    var i = i_out;
    let x_dst = i % params.dst_w;
    i = i / params.dst_w;
    let y_dst = i % params.dst_h;
    i = i / params.dst_h;
    let z_dst = i % params.dst_z;
    let n_dst = i / params.dst_z;
    // scale factors: destination extent over source extent, per dimension
    var sf0 = f32(params.dst_w) / f32(params.src_w);
    var sf1 = f32(params.dst_h) / f32(params.src_h);
    var sf2 = f32(params.dst_z) / f32(params.src_z);
    var sf3 = f32(params.dst_n) / f32(params.src_n);
    let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0;
    // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners
    var pixel_offset = 0.5;
    if (align_corners) {
        pixel_offset = 0.0;
        // With align_corners the endpoints map exactly onto the first/last
        // source pixels, so the scale uses (extent - 1) ratios.
        if (params.dst_w > 1 && params.src_w > 1) {
            sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1);
        }
        if (params.dst_h > 1 && params.src_h > 1) {
            sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1);
        }
    }
    // z and n are never interpolated: select a single source slice by
    // truncating the scaled coordinate, clamped to the source extent.
    let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2)));
    let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3)));
    var result = 0.0;
#if defined(NEAREST)
    // Nearest neighbor: floor of the scaled coordinate, clamped in-range.
    let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0)));
    let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1)));
    result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src);
#elif defined(BILINEAR)
#if defined(ANTIALIAS)
    // Antialiased bilinear: triangle filter over a variable support region.
    // When downscaling (sf < 1) the support widens to cover all source
    // pixels that map into this destination pixel.
    let support0 = max(1.0f / sf0, 1.0f);
    let support1 = max(1.0f / sf1, 1.0f);
    let invscale0 = 1.0 / support0;
    let invscale1 = 1.0 / support1;
    let fx = (f32(x_dst) + pixel_offset) / sf0;
    let fy = (f32(y_dst) + pixel_offset) / sf1;
    let x_min = max(i32(fx - support0 + pixel_offset), 0);
    let y_min = max(i32(fy - support1 + pixel_offset), 0);
    let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w));
    let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h));
    var weighted_sum = 0.0;
    var total_weight = 0.0;
    for (var x = x_min; x < x_max; x += 1) {
        let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0);
        for (var y = y_min; y < y_max; y += 1) {
            let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0);
            let w = wx * wy;
            if (w > 0.0) {
                weighted_sum += get_clamped_input(x, y, z_src, n_src) * w;
                total_weight += w;
            }
        }
    }
    // Normalize by the accumulated weight; result stays 0 if the window
    // contributed nothing.
    if (total_weight > 0.0) {
        result = weighted_sum / total_weight;
    }
#else
    // Plain bilinear: blend the 2x2 neighborhood around the (possibly
    // fractional) source coordinate.
    let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
    let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;
    let x0 = i32(floor(fx));
    let y0 = i32(floor(fy));
    let dx = clamp(fx - f32(x0), 0.0, 1.0);
    let dy = clamp(fy - f32(y0), 0.0, 1.0);
    let a = get_clamped_input(x0, y0, z_src, n_src);
    let b = get_clamped_input(x0 + 1, y0, z_src, n_src);
    let c = get_clamped_input(x0, y0 + 1, z_src, n_src);
    let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);
    let wa = (1.0 - dx) * (1.0 - dy);
    let wb = dx * (1.0 - dy);
    let wc = (1.0 - dx) * dy;
    let wd = dx * dy;
    result = a * wa + b * wb + c * wc + d * wd;
#endif
#elif defined(BICUBIC)
    // bicubic convolution with alpha = -0.75 (PyTorch default)
    let alpha = -0.75;
    let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
    let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;
    let x0 = i32(floor(fx));
    let y0 = i32(floor(fy));
    let dx = fx - f32(x0);
    let dy = fy - f32(y0);
    // horizontal weights for offsets -1, 0, 1, 2
    let wx0 = cubic_weight(dx + 1.0, alpha);
    let wx1 = cubic_weight(dx, alpha);
    let wx2 = cubic_weight(1.0 - dx, alpha);
    let wx3 = cubic_weight(2.0 - dx, alpha);
    // vertical weights for offsets -1, 0, 1, 2
    let wy0 = cubic_weight(dy + 1.0, alpha);
    let wy1 = cubic_weight(dy, alpha);
    let wy2 = cubic_weight(1.0 - dy, alpha);
    let wy3 = cubic_weight(2.0 - dy, alpha);
    // intermediate horizontal interpolation for 4x4 grid of pixels
    // x0-1, x0, x0+1, x0+2, y0-1
    let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src);
    let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src);
    let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src);
    let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src);
    let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3;
    // x0-1, x0, x0+1, x0+2, y0
    let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src);
    let q1 = get_clamped_input(x0, y0, z_src, n_src);
    let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src);
    let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src);
    let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3;
    // x0-1, x0, x0+1, x0+2, y0+1
    let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src);
    let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src);
    let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);
    let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src);
    let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3;
    // x0-1, x0, x0+1, x0+2, y0+2
    let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src);
    let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src);
    let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src);
    let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src);
    let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3;
    // final vertical interpolation
    result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3;
#endif
    // Write the interpolated value, converting to the destination type
    // (f16 or f32, selected by the DST_F16 define).
    let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3;
    output[dst_idx] = DST_TYPE(result);
}