From c5cde8c7171dea86345338a7dee64d939b4c09cf Mon Sep 17 00:00:00 2001 From: l8bloom Date: Wed, 27 May 2026 16:59:08 +0200 Subject: [PATCH] vulkan: add REPEAT op support for f16 to f16. (llama/23298) * feat: extend repeat op for vulkan * feat: add repeat_f16 vulkan pipeline * fix: ensure same dst and src types * fix: use type_size instead of data types * fix: use int16 and int32 for repeat shader op * chore: rename repeat_f* to repeat_i* * chore: rename repeat vulkan pipelines --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 15 +++++++++++---- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f45b9cfd1..99b42f3bd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -768,7 +768,8 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; - vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; @@ -4708,9 +4709,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_UNARY(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -9738,7 +9741,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { - return ctx->device->pipeline_repeat_f32; + return ctx->device->pipeline_repeat_i32; + } + if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) { + return ctx->device->pipeline_repeat_i16; } return nullptr; case GGML_OP_REPEAT_BACK: @@ -16253,7 +16259,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } case GGML_OP_REPEAT: - return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) && + (ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2); case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a0aac3912..24b9d25f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -798,9 +798,11 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});