ggml-webgpu: add the upscale shader (llama/22419)
* shader(upscale): add the upscale shader with nearest, bilinear and bicubic implementations * shader(upscale): use macro
This commit is contained in:
parent
b34a9f3d83
commit
ccd04522f9
|
|
@ -1,6 +1,7 @@
|
|||
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
||||
#define GGML_WEBGPU_SHADER_LIB_HPP
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
#include "ggml.h"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
|
@ -405,6 +406,31 @@ struct ggml_webgpu_scale_pipeline_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
/** Upscale **/
|
||||
|
||||
struct ggml_webgpu_upscale_pipeline_key {
|
||||
ggml_type input_type;
|
||||
ggml_type output_type;
|
||||
uint32_t base_mode;
|
||||
bool antialias;
|
||||
|
||||
bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const {
|
||||
return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode &&
|
||||
antialias == other.antialias;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_upscale_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.input_type);
|
||||
ggml_webgpu_hash_combine(seed, key.output_type);
|
||||
ggml_webgpu_hash_combine(seed, key.base_mode);
|
||||
ggml_webgpu_hash_combine(seed, key.antialias);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Concat **/
|
||||
|
||||
struct ggml_webgpu_concat_pipeline_key {
|
||||
|
|
@ -1049,6 +1075,8 @@ class ggml_webgpu_shader_lib {
|
|||
webgpu_pipeline,
|
||||
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
|
||||
rms_norm_mul_pipelines;
|
||||
std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash>
|
||||
upscale_pipelines;
|
||||
|
||||
public:
|
||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||
|
|
@ -2947,6 +2975,72 @@ class ggml_webgpu_shader_lib {
|
|||
return im2col_pipelines[key];
|
||||
}
|
||||
|
||||
// Return the compute pipeline specialized for the GGML_OP_UPSCALE node in
// `context`, building and caching it on first use. The specialization is
// keyed on src/dst element types, the base scale mode and the antialias flag;
// each combination gets its own preprocessed WGSL variant.
webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // op param 0 packs the base mode in the low byte and flag bits above it
    const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0);
    const uint32_t base_mode  = mode_flags & 0xFFu;
    const bool     antialias  = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u;

    ggml_webgpu_upscale_pipeline_key key = {};
    key.input_type  = context.src0->type;
    key.output_type = context.dst->type;
    key.base_mode   = base_mode;
    key.antialias   = antialias;

    auto it = upscale_pipelines.find(key);
    if (it != upscale_pipelines.end()) {
        return it->second;
    }

    // `defines` drives the WGSL preprocessor; `variant` is a human-readable
    // label for the pipeline (useful in debug/profiling output).
    std::vector<std::string> defines;
    std::string              variant = "upscale";

    if (key.input_type == GGML_TYPE_F16) {
        defines.push_back("SRC_F16");
        variant += "_src_f16";
    } else {
        variant += "_src_f32";
    }

    if (key.output_type == GGML_TYPE_F16) {
        defines.push_back("DST_F16");
        variant += "_dst_f16";
    } else {
        variant += "_dst_f32";
    }

    switch (base_mode) {
        case GGML_SCALE_MODE_NEAREST:
            defines.push_back("NEAREST");
            variant += "_nearest";
            break;
        case GGML_SCALE_MODE_BILINEAR:
            defines.push_back("BILINEAR");
            variant += "_bilinear";
            break;
        case GGML_SCALE_MODE_BICUBIC:
            defines.push_back("BICUBIC");
            variant += "_bicubic";
            break;
        default:
            GGML_ABORT("Unsupported upscale mode");
    }

    if (antialias) {
        defines.push_back("ANTIALIAS");
        variant += "_aa";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_upscale, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Insert once and return the local copy: the original stored via
    // operator[] and then re-indexed the map, paying a second hash lookup.
    upscale_pipelines.emplace(key, pipeline);
    return pipeline;
}
|
||||
|
||||
private:
|
||||
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||
std::string shader_code,
|
||||
|
|
|
|||
|
|
@ -2824,6 +2824,49 @@ static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph,
|
|||
return true;
|
||||
}
|
||||
|
||||
// Encode a GGML_OP_UPSCALE node for the WebGPU backend: one shader invocation
// per destination element. The uniform payload layout must match the shader's
// Params struct: offsets, src strides, dst strides, src extents, dst extents,
// then the raw mode flags.
static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) {
    const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0);

    const size_t src_tsize = ggml_type_size(src->type);
    const size_t dst_tsize = ggml_type_size(dst->type);

    std::vector<uint32_t> params;
    params.reserve(19);

    // buffer misalignment offsets, expressed in elements
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_tsize));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_tsize));

    // per-dimension strides, converted from bytes to elements
    for (int d = 0; d < 4; ++d) {
        params.push_back((uint32_t) (src->nb[d] / src_tsize));
    }
    for (int d = 0; d < 4; ++d) {
        params.push_back((uint32_t) (dst->nb[d] / dst_tsize));
    }

    // logical extents of source and destination
    for (int d = 0; d < 4; ++d) {
        params.push_back((uint32_t) src->ne[d]);
    }
    for (int d = 0; d < 4; ++d) {
        params.push_back((uint32_t) dst->ne[d]);
    }

    params.push_back(mode_flags);

    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
                                                  ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) };

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0        = src;
    shader_lib_ctx.dst         = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;

    webgpu_pipeline pipeline  = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx);
    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    // Spread the flat workload over the x and y dispatch dimensions so we
    // never exceed the device's per-dimension workgroup limit.
    uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
    uint32_t wg_x     = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
    uint32_t wg_y     = CEIL_DIV(total_wg, wg_x);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
|
||||
|
||||
// Returns the encoded command, or std::nullopt if the operation is a no-op
|
||||
static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||
ggml_cgraph * cgraph,
|
||||
|
|
@ -2931,6 +2974,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
|||
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
|
||||
case GGML_OP_IM2COL:
|
||||
return ggml_webgpu_im2col(ctx, src0, src1, node);
|
||||
case GGML_OP_UPSCALE:
|
||||
return ggml_webgpu_upscale(ctx, src0, node);
|
||||
default:
|
||||
return std::nullopt;
|
||||
}
|
||||
|
|
@ -4163,6 +4208,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
case GGML_OP_SUM_ROWS:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
|
||||
break;
|
||||
case GGML_OP_UPSCALE:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,240 @@
|
|||
// f16 storage requires the WGSL f16 extension.
#if defined(SRC_F16) || defined(DST_F16)
enable f16;
#endif

// Element type of the source buffer, chosen at preprocessing time.
#ifdef SRC_F16
#define SRC_TYPE f16
#else
#define SRC_TYPE f32
#endif

// Element type of the destination buffer, chosen at preprocessing time.
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif

@group(0) @binding(0)
var<storage, read_write> input: array<SRC_TYPE>;

@group(0) @binding(1)
var<storage, read_write> output: array<DST_TYPE>;

// Uniform parameters; field order must match the params vector built on the
// host side (offsets, src strides, dst strides, src extents, dst extents, flags).
struct Params {
    offset_i: u32,
    offset_o: u32,

    // element strides
    si0: u32, si1: u32, si2: u32, si3: u32,
    so0: u32, so1: u32, so2: u32, so3: u32,

    // source extents (width, height, depth, batch)
    src_w: u32,
    src_h: u32,
    src_z: u32,
    src_n: u32,

    // destination extents (width, height, depth, batch)
    dst_w: u32,
    dst_h: u32,
    dst_z: u32,
    dst_n: u32,

    // raw op param 0: base mode in the low byte, flag bits above it
    mode_flags: u32,
};

@group(0) @binding(2)
var<uniform> params: Params;

// Flag bit selecting align-corners coordinate mapping (bit 8 of mode_flags).
const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u;
|
||||
|
||||
// Fetch one source element as f32, clamping x/y into the valid source
// extent (replicate-border addressing); z/n are assumed in range.
fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 {
    let sx = u32(min(max(x, 0), i32(params.src_w) - 1));
    let sy = u32(min(max(y, 0), i32(params.src_h) - 1));
    let idx = params.offset_i + sx * params.si0 + sy * params.si1 + z * params.si2 + n * params.si3;
    return f32(input[idx]);
}
|
||||
|
||||
// Cubic convolution kernel weight for distance t with free parameter a
// (Keys' piecewise cubic); zero outside |t| > 2.
fn cubic_weight(t: f32, a: f32) -> f32 {
    let x = abs(t);
    if (x > 2.0) {
        return 0.0;
    }
    if (x <= 1.0) {
        return (a + 2.0) * x * x * x - (a + 3.0) * x * x + 1.0;
    }
    return a * x * x * x - 5.0 * a * x * x + 8.0 * a * x - 4.0 * a;
}
|
||||
|
||||
// Upscale kernel: each invocation computes exactly one destination element.
// The interpolation mode (NEAREST / BILINEAR / BICUBIC, optionally ANTIALIAS)
// is selected at preprocessing time.
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {

    // Flatten the 2D dispatch back into a linear output index; the host splits
    // the workload over x/y to stay under the per-dimension workgroup limit.
    let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y;
    let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n;

    if (i_out >= total) {
        return;
    }

    // decode the linear index into destination coordinates (x, y, z, n)
    var i = i_out;
    let x_dst = i % params.dst_w;
    i = i / params.dst_w;
    let y_dst = i % params.dst_h;
    i = i / params.dst_h;
    let z_dst = i % params.dst_z;
    let n_dst = i / params.dst_z;

    // per-dimension scale factors (dst extent / src extent)
    var sf0 = f32(params.dst_w) / f32(params.src_w);
    var sf1 = f32(params.dst_h) / f32(params.src_h);
    var sf2 = f32(params.dst_z) / f32(params.src_z);
    var sf3 = f32(params.dst_n) / f32(params.src_n);

    let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0;

    // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners.
    // Align-corners also remaps the x/y scale factors so that the first and
    // last samples of src and dst coincide (guarded against extent 1).
    var pixel_offset = 0.5;
    if (align_corners) {
        pixel_offset = 0.0;
        if (params.dst_w > 1 && params.src_w > 1) {
            sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1);
        }
        if (params.dst_h > 1 && params.src_h > 1) {
            sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1);
        }
    }

    // z and n always use nearest selection, clamped to the source extent
    let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2)));
    let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3)));

    var result = 0.0;

#if defined(NEAREST)

    // nearest neighbour: pick the source texel whose cell contains the output
    let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0)));
    let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1)));

    result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src);

#elif defined(BILINEAR)

#if defined(ANTIALIAS)

    // Antialiased bilinear: triangle filter over a variable support region.
    // When downscaling (sf < 1) the support widens to 1/sf so every covered
    // source texel contributes; when upscaling it stays at 1.
    let support0 = max(1.0f / sf0, 1.0f);
    let support1 = max(1.0f / sf1, 1.0f);
    let invscale0 = 1.0 / support0;
    let invscale1 = 1.0 / support1;

    // continuous source-space center of this destination pixel
    let fx = (f32(x_dst) + pixel_offset) / sf0;
    let fy = (f32(y_dst) + pixel_offset) / sf1;

    // bounding box of contributing source texels, clipped to the source
    let x_min = max(i32(fx - support0 + pixel_offset), 0);
    let y_min = max(i32(fy - support1 + pixel_offset), 0);
    let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w));
    let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h));

    var weighted_sum = 0.0;
    var total_weight = 0.0;

    // accumulate triangle-weighted samples and normalize at the end
    for (var x = x_min; x < x_max; x += 1) {
        let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0);
        for (var y = y_min; y < y_max; y += 1) {
            let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0);
            let w = wx * wy;
            if (w > 0.0) {
                weighted_sum += get_clamped_input(x, y, z_src, n_src) * w;
                total_weight += w;
            }
        }
    }

    if (total_weight > 0.0) {
        result = weighted_sum / total_weight;
    }

#else

    // Plain bilinear: blend the 2x2 neighbourhood around the back-projected
    // coordinate; edge texels are replicated by get_clamped_input.
    let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
    let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;
    let x0 = i32(floor(fx));
    let y0 = i32(floor(fy));
    let dx = clamp(fx - f32(x0), 0.0, 1.0);
    let dy = clamp(fy - f32(y0), 0.0, 1.0);
    let a = get_clamped_input(x0, y0, z_src, n_src);
    let b = get_clamped_input(x0 + 1, y0, z_src, n_src);
    let c = get_clamped_input(x0, y0 + 1, z_src, n_src);
    let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);

    let wa = (1.0 - dx) * (1.0 - dy);
    let wb = dx * (1.0 - dy);
    let wc = (1.0 - dx) * dy;
    let wd = dx * dy;

    result = a * wa + b * wb + c * wc + d * wd;

#endif

#elif defined(BICUBIC)

    // bicubic convolution with alpha = -0.75 (PyTorch default)
    let alpha = -0.75;
    let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
    let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;

    let x0 = i32(floor(fx));
    let y0 = i32(floor(fy));
    let dx = fx - f32(x0);
    let dy = fy - f32(y0);

    // horizontal weights for offsets -1, 0, 1, 2
    let wx0 = cubic_weight(dx + 1.0, alpha);
    let wx1 = cubic_weight(dx, alpha);
    let wx2 = cubic_weight(1.0 - dx, alpha);
    let wx3 = cubic_weight(2.0 - dx, alpha);

    // vertical weights for offsets -1, 0, 1, 2
    let wy0 = cubic_weight(dy + 1.0, alpha);
    let wy1 = cubic_weight(dy, alpha);
    let wy2 = cubic_weight(1.0 - dy, alpha);
    let wy3 = cubic_weight(2.0 - dy, alpha);

    // intermediate horizontal interpolation for 4x4 grid of pixels
    // x0-1, x0, x0+1, x0+2, y0-1
    let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src);
    let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src);
    let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src);
    let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src);
    let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3;

    // x0-1, x0, x0+1, x0+2, y0
    let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src);
    let q1 = get_clamped_input(x0, y0, z_src, n_src);
    let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src);
    let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src);
    let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3;

    // x0-1, x0, x0+1, x0+2, y0+1
    let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src);
    let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src);
    let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);
    let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src);
    let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3;

    // x0-1, x0, x0+1, x0+2, y0+2
    let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src);
    let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src);
    let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src);
    let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src);
    let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3;

    // final vertical interpolation
    result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3;

#endif

    // scatter the result to the strided destination location
    let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3;
    output[dst_idx] = DST_TYPE(result);
}
|
||||
Loading…
Reference in New Issue