vulkan: add col2im_1d op (llama/24425)
* vulkan: add GGML_OP_COL2IM_1D, follow-up to the CPU op * vulkan: col2im_1d bounded gather loop instead of full-K scan with modulo * vulkan: col2im_1d address review from @jeffbolznv * vulkan: col2im_1d return nullptr for unsupported types, address review from @0cc4m
This commit is contained in:
parent
d77b2f704c
commit
c8f370a460
|
|
@ -902,6 +902,9 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_conv_transpose_1d_f32;
|
||||
vk_pipeline pipeline_col2im_1d_f32;
|
||||
vk_pipeline pipeline_col2im_1d_f16;
|
||||
vk_pipeline pipeline_col2im_1d_bf16;
|
||||
vk_pipeline pipeline_snake_f32;
|
||||
vk_pipeline pipeline_snake_f16;
|
||||
vk_pipeline pipeline_snake_bf16;
|
||||
|
|
@ -1552,6 +1555,16 @@ struct vk_op_timestep_embedding_push_constants {
|
|||
uint32_t max_period;
|
||||
};
|
||||
|
||||
struct vk_op_col2im_1d_push_constants {
|
||||
uint32_t T_out;
|
||||
uint32_t OC;
|
||||
uint32_t K_OC;
|
||||
uint32_t T_in;
|
||||
uint32_t K;
|
||||
int32_t stride;
|
||||
int32_t p0;
|
||||
};
|
||||
|
||||
struct vk_op_conv_transpose_1d_push_constants {
|
||||
uint32_t Cout;
|
||||
uint32_t Cin;
|
||||
|
|
@ -5203,6 +5216,9 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f32, "col2im_1d_f32", col2im_1d_f32_len, col2im_1d_f32_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f16, "col2im_1d_f16", col2im_1d_f16_len, col2im_1d_f16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_bf16, "col2im_1d_bf16", col2im_1d_bf16_len, col2im_1d_bf16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
|
@ -10702,6 +10718,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_conv_transpose_1d_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: return ctx->device->pipeline_col2im_1d_f32;
|
||||
case GGML_TYPE_F16: return ctx->device->pipeline_col2im_1d_f16;
|
||||
case GGML_TYPE_BF16: return ctx->device->pipeline_col2im_1d_bf16;
|
||||
default: return nullptr;
|
||||
}
|
||||
case GGML_OP_POOL_2D:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_pool2d_f32;
|
||||
|
|
@ -11147,6 +11170,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
{
|
||||
elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
|
||||
} break;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
{
|
||||
elements = { uint32_t(dst->ne[0]), uint32_t(dst->ne[1]), 1 };
|
||||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
const uint32_t N = dst->ne[3];
|
||||
|
|
@ -12936,6 +12963,32 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
|
|||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_col2im_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
// src0: [K_OC, T_in] columns from matmul
|
||||
// dst: [T_out, OC]
|
||||
|
||||
const int32_t stride = dst->op_params[0];
|
||||
const int32_t oc = dst->op_params[1];
|
||||
const int32_t p0 = dst->op_params[2];
|
||||
|
||||
const uint32_t K_OC = static_cast<uint32_t>(src0->ne[0]);
|
||||
const uint32_t T_in = static_cast<uint32_t>(src0->ne[1]);
|
||||
const uint32_t T_out = static_cast<uint32_t>(dst->ne[0]);
|
||||
const uint32_t OC = static_cast<uint32_t>(oc);
|
||||
const uint32_t K = K_OC / OC;
|
||||
|
||||
vk_op_col2im_1d_push_constants p{};
|
||||
p.T_out = T_out;
|
||||
p.OC = OC;
|
||||
p.K_OC = K_OC;
|
||||
p.T_in = T_in;
|
||||
p.K = K;
|
||||
p.stride = stride;
|
||||
p.p0 = p0;
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COL2IM_1D, std::move(p));
|
||||
}
|
||||
|
||||
// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
|
||||
// Match the naive mul -> sin -> sqr -> mul -> add chain and run the
|
||||
// dedicated kernel directly. The pattern is validated by
|
||||
|
|
@ -14423,6 +14476,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
ggml_vk_col2im_1d(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
|
||||
|
|
@ -17188,6 +17245,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 ||
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16) &&
|
||||
op->type == op->src[0]->type &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op);
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
{
|
||||
|
|
@ -18019,6 +18083,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
const int32_t p0 = tensor->op_params[1];
|
||||
const int32_t d0 = tensor->op_params[2];
|
||||
tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
|
||||
} else if (tensor->op == GGML_OP_COL2IM_1D) {
|
||||
const int32_t stride = tensor->op_params[0];
|
||||
const int32_t oc = tensor->op_params[1];
|
||||
const int32_t p0 = tensor->op_params[2];
|
||||
tensor_clone = ggml_col2im_1d(ggml_ctx, src_clone[0], stride, oc, p0);
|
||||
} else if (tensor->op == GGML_OP_POOL_2D) {
|
||||
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
|
||||
const int32_t k0 = tensor->op_params[1];
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // columns: [K_OC, T_in]
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // output: [T_out, OC]
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t T_out;
|
||||
uint32_t OC;
|
||||
uint32_t K_OC;
|
||||
uint32_t T_in;
|
||||
uint32_t K;
|
||||
int32_t stride;
|
||||
int32_t p0;
|
||||
} p;
|
||||
|
||||
// Load A_TYPE to float
|
||||
float load_col(uint32_t idx) {
|
||||
#if defined(DATA_A_BF16)
|
||||
return bf16_to_fp32(uint32_t(data_a[idx]));
|
||||
#else
|
||||
return float(data_a[idx]);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Store float as D_TYPE
|
||||
void store_dst(uint32_t idx, float v) {
|
||||
#if defined(DATA_A_BF16)
|
||||
data_d[idx] = D_TYPE(fp32_to_bf16(v));
|
||||
#else
|
||||
data_d[idx] = D_TYPE(v);
|
||||
#endif
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint32_t t_out = gl_GlobalInvocationID.x;
|
||||
const uint32_t oc = gl_GlobalInvocationID.y;
|
||||
if (t_out >= p.T_out || oc >= p.OC) return;
|
||||
|
||||
const int32_t t_abs = int32_t(t_out) + p.p0; // absolute position in uncropped signal
|
||||
|
||||
// Gather: only the ceil(K/stride) columns that scatter into t_abs, no modulo
|
||||
int32_t t_in_min = (t_abs - int32_t(p.K) + p.stride) / p.stride;
|
||||
if (t_in_min < 0) t_in_min = 0;
|
||||
int32_t t_in_max = t_abs / p.stride;
|
||||
if (t_in_max >= int32_t(p.T_in)) t_in_max = int32_t(p.T_in) - 1;
|
||||
|
||||
float val = 0.0;
|
||||
for (int32_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
|
||||
int32_t k = t_abs - t_in * p.stride;
|
||||
// col layout: [K_OC, T_in], column index = oc * K + k
|
||||
uint32_t col_idx = (oc * p.K + uint32_t(k)) + uint32_t(t_in) * p.K_OC;
|
||||
val += load_col(col_idx);
|
||||
}
|
||||
|
||||
// dst layout: [T_out, OC], element (t_out, oc) = t_out + oc * T_out
|
||||
store_dst(t_out + oc * p.T_out, val);
|
||||
}
|
||||
|
|
@ -1003,6 +1003,9 @@ void process_shaders() {
|
|||
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("col2im_1d_f32", "col2im_1d.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("col2im_1d_f16", "col2im_1d.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("col2im_1d_bf16", "col2im_1d.comp", {{"DATA_A_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
|
||||
|
||||
string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
|
|
|
|||
Loading…
Reference in New Issue