vulkan: add col2im_1d op (llama/24425)

* vulkan: add GGML_OP_COL2IM_1D, follow-up to the CPU op

* vulkan: col2im_1d bounded gather loop instead of full-K scan with modulo

* vulkan: col2im_1d address review from @jeffbolznv

* vulkan: col2im_1d return nullptr for unsupported types, address review from @0cc4m
This commit is contained in:
Pascal 2026-06-16 06:34:43 +02:00 committed by Georgi Gerganov
parent d77b2f704c
commit c8f370a460
3 changed files with 133 additions and 0 deletions

View File

@ -902,6 +902,9 @@ struct vk_device_struct {
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_conv_transpose_1d_f32;
vk_pipeline pipeline_col2im_1d_f32;
vk_pipeline pipeline_col2im_1d_f16;
vk_pipeline pipeline_col2im_1d_bf16;
vk_pipeline pipeline_snake_f32;
vk_pipeline pipeline_snake_f16;
vk_pipeline pipeline_snake_bf16;
@ -1552,6 +1555,16 @@ struct vk_op_timestep_embedding_push_constants {
uint32_t max_period;
};
struct vk_op_col2im_1d_push_constants {
uint32_t T_out;
uint32_t OC;
uint32_t K_OC;
uint32_t T_in;
uint32_t K;
int32_t stride;
int32_t p0;
};
struct vk_op_conv_transpose_1d_push_constants {
uint32_t Cout;
uint32_t Cin;
@ -5203,6 +5216,9 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f32, "col2im_1d_f32", col2im_1d_f32_len, col2im_1d_f32_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f16, "col2im_1d_f16", col2im_1d_f16_len, col2im_1d_f16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_bf16, "col2im_1d_bf16", col2im_1d_bf16_len, col2im_1d_bf16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
@ -10702,6 +10718,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_conv_transpose_1d_f32;
}
return nullptr;
case GGML_OP_COL2IM_1D:
switch (src0->type) {
case GGML_TYPE_F32: return ctx->device->pipeline_col2im_1d_f32;
case GGML_TYPE_F16: return ctx->device->pipeline_col2im_1d_f16;
case GGML_TYPE_BF16: return ctx->device->pipeline_col2im_1d_bf16;
default: return nullptr;
}
case GGML_OP_POOL_2D:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_pool2d_f32;
@ -11147,6 +11170,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
{
elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
} break;
case GGML_OP_COL2IM_1D:
{
elements = { uint32_t(dst->ne[0]), uint32_t(dst->ne[1]), 1 };
} break;
case GGML_OP_POOL_2D:
{
const uint32_t N = dst->ne[3];
@ -12936,6 +12963,32 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
}
static void ggml_vk_col2im_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
// src0: [K_OC, T_in] columns from matmul
// dst: [T_out, OC]
const int32_t stride = dst->op_params[0];
const int32_t oc = dst->op_params[1];
const int32_t p0 = dst->op_params[2];
const uint32_t K_OC = static_cast<uint32_t>(src0->ne[0]);
const uint32_t T_in = static_cast<uint32_t>(src0->ne[1]);
const uint32_t T_out = static_cast<uint32_t>(dst->ne[0]);
const uint32_t OC = static_cast<uint32_t>(oc);
const uint32_t K = K_OC / OC;
vk_op_col2im_1d_push_constants p{};
p.T_out = T_out;
p.OC = OC;
p.K_OC = K_OC;
p.T_in = T_in;
p.K = K;
p.stride = stride;
p.p0 = p0;
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COL2IM_1D, std::move(p));
}
// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
// Match the naive mul -> sin -> sqr -> mul -> add chain and run the
// dedicated kernel directly. The pattern is validated by
@ -14423,6 +14476,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
break;
case GGML_OP_COL2IM_1D:
ggml_vk_col2im_1d(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
@ -17188,6 +17245,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_COL2IM_1D:
return (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16) &&
op->type == op->src[0]->type &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op);
case GGML_OP_CONV_2D:
case GGML_OP_CONV_TRANSPOSE_2D:
{
@ -18019,6 +18083,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t p0 = tensor->op_params[1];
const int32_t d0 = tensor->op_params[2];
tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
} else if (tensor->op == GGML_OP_COL2IM_1D) {
const int32_t stride = tensor->op_params[0];
const int32_t oc = tensor->op_params[1];
const int32_t p0 = tensor->op_params[2];
tensor_clone = ggml_col2im_1d(ggml_ctx, src_clone[0], stride, oc, p0);
} else if (tensor->op == GGML_OP_POOL_2D) {
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
const int32_t k0 = tensor->op_params[1];

View File

@ -0,0 +1,61 @@
#version 450
#include "types.glsl"
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // columns: [K_OC, T_in]
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // output: [T_out, OC]
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (push_constant) uniform parameter {
uint32_t T_out;
uint32_t OC;
uint32_t K_OC;
uint32_t T_in;
uint32_t K;
int32_t stride;
int32_t p0;
} p;
// Load A_TYPE to float
float load_col(uint32_t idx) {
#if defined(DATA_A_BF16)
return bf16_to_fp32(uint32_t(data_a[idx]));
#else
return float(data_a[idx]);
#endif
}
// Store float as D_TYPE
void store_dst(uint32_t idx, float v) {
#if defined(DATA_A_BF16)
data_d[idx] = D_TYPE(fp32_to_bf16(v));
#else
data_d[idx] = D_TYPE(v);
#endif
}
void main() {
const uint32_t t_out = gl_GlobalInvocationID.x;
const uint32_t oc = gl_GlobalInvocationID.y;
if (t_out >= p.T_out || oc >= p.OC) return;
const int32_t t_abs = int32_t(t_out) + p.p0; // absolute position in uncropped signal
// Gather: only the ceil(K/stride) columns that scatter into t_abs, no modulo
int32_t t_in_min = (t_abs - int32_t(p.K) + p.stride) / p.stride;
if (t_in_min < 0) t_in_min = 0;
int32_t t_in_max = t_abs / p.stride;
if (t_in_max >= int32_t(p.T_in)) t_in_max = int32_t(p.T_in) - 1;
float val = 0.0;
for (int32_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
int32_t k = t_abs - t_in * p.stride;
// col layout: [K_OC, T_in], column index = oc * K + k
uint32_t col_idx = (oc * p.K + uint32_t(k)) + uint32_t(t_in) * p.K_OC;
val += load_col(col_idx);
}
// dst layout: [T_out, OC], element (t_out, oc) = t_out + oc * T_out
store_dst(t_out + oc * p.T_out, val);
}

View File

@ -1003,6 +1003,9 @@ void process_shaders() {
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("col2im_1d_f32", "col2im_1d.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("col2im_1d_f16", "col2im_1d.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("col2im_1d_bf16", "col2im_1d.comp", {{"DATA_A_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});