vulkan: implement ABS and NEG (llama/17245)

* docs: update Vulkan ops

* vulkan: add NEG op

* vulkan: add ABS op

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Giuseppe Scrivano, 2025-11-15 12:00:29 +01:00, committed by Georgi Gerganov
commit 4c4e663da0 (parent e1846fc599)
4 changed files with 67 additions and 0 deletions
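
For context, GGML_UNARY_OP_ABS and GGML_UNARY_OP_NEG are the graph-level ops these new pipelines accelerate. A minimal sketch of how they appear in a ggml compute graph, assuming the current public ggml C API; the buffer size, thread count, and the CPU compute call are illustrative, not part of this commit:

#include "ggml.h"
#include "ggml-cpu.h" // assumption: ggml_graph_compute_with_ctx is declared here in recent trees

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024, // illustrative scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // y = |x| and z = -x: the two unary ops this commit wires up for Vulkan
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * y = ggml_abs(ctx, x);
    struct ggml_tensor * z = ggml_neg(ctx, x);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_build_forward_expand(gf, z);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1); // CPU reference path

    ggml_free(ctx);
    return 0;
}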

ggml-vulkan.cpp

@@ -656,10 +656,12 @@ struct vk_device_struct {
    vk_pipeline pipeline_gelu_quick[2];
    vk_pipeline pipeline_silu[2];
    vk_pipeline pipeline_relu[2];
    vk_pipeline pipeline_neg[2];
    vk_pipeline pipeline_tanh[2];
    vk_pipeline pipeline_sigmoid[2];
    vk_pipeline pipeline_hardsigmoid[2];
    vk_pipeline pipeline_hardswish[2];
    vk_pipeline pipeline_abs[2];

    vk_pipeline pipeline_geglu[2];
    vk_pipeline pipeline_reglu[2];
@@ -3804,10 +3806,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
    CREATE_UNARY(gelu_quick)
    CREATE_UNARY(silu)
    CREATE_UNARY(relu)
    CREATE_UNARY(neg)
    CREATE_UNARY(tanh)
    CREATE_UNARY(sigmoid)
    CREATE_UNARY(hardsigmoid)
    CREATE_UNARY(hardswish)
    CREATE_UNARY(abs)
#undef CREATE_UNARY

#define CREATE_UNARY_RTE(name) \
@@ -8170,6 +8174,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
    case GGML_UNARY_OP_RELU:
        return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
    case GGML_UNARY_OP_NEG:
        return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
    case GGML_UNARY_OP_TANH:
        return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
    case GGML_UNARY_OP_SIGMOID:
@@ -8178,6 +8184,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
    case GGML_UNARY_OP_HARDSWISH:
        return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
    case GGML_UNARY_OP_ABS:
        return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
    default:
        break;
    }
@@ -11106,10 +11114,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_UNARY_OP_GELU_ERF:
    case GGML_UNARY_OP_GELU_QUICK:
    case GGML_UNARY_OP_RELU:
    case GGML_UNARY_OP_NEG:
    case GGML_UNARY_OP_TANH:
    case GGML_UNARY_OP_SIGMOID:
    case GGML_UNARY_OP_HARDSIGMOID:
    case GGML_UNARY_OP_HARDSWISH:
    case GGML_UNARY_OP_ABS:
        break;
    default:
        return false;
@@ -11436,10 +11446,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_UNARY_OP_GELU_ERF:
    case GGML_UNARY_OP_GELU_QUICK:
    case GGML_UNARY_OP_RELU:
    case GGML_UNARY_OP_NEG:
    case GGML_UNARY_OP_TANH:
    case GGML_UNARY_OP_SIGMOID:
    case GGML_UNARY_OP_HARDSIGMOID:
    case GGML_UNARY_OP_HARDSWISH:
    case GGML_UNARY_OP_ABS:
        ggml_vk_unary(ctx, compute_ctx, src0, node);
        break;
    default:
@@ -11706,10 +11718,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_UNARY_OP_GELU_ERF:
    case GGML_UNARY_OP_GELU_QUICK:
    case GGML_UNARY_OP_RELU:
    case GGML_UNARY_OP_NEG:
    case GGML_UNARY_OP_TANH:
    case GGML_UNARY_OP_SIGMOID:
    case GGML_UNARY_OP_HARDSIGMOID:
    case GGML_UNARY_OP_HARDSWISH:
    case GGML_UNARY_OP_ABS:
        buf = tensor->buffer;
        break;
    default:
@@ -13235,10 +13249,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
    case GGML_UNARY_OP_GELU_QUICK:
    case GGML_UNARY_OP_SILU:
    case GGML_UNARY_OP_RELU:
    case GGML_UNARY_OP_NEG:
    case GGML_UNARY_OP_TANH:
    case GGML_UNARY_OP_SIGMOID:
    case GGML_UNARY_OP_HARDSIGMOID:
    case GGML_UNARY_OP_HARDSWISH:
    case GGML_UNARY_OP_ABS:
        return ggml_is_contiguous(op->src[0]) &&
               (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
               (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14116,6 +14132,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_UNARY_OP_RELU:
        tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
        break;
    case GGML_UNARY_OP_NEG:
        tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
        break;
    case GGML_UNARY_OP_TANH:
        tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
        break;
@@ -14128,6 +14147,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_UNARY_OP_HARDSWISH:
        tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
        break;
    case GGML_UNARY_OP_ABS:
        tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
        break;
    default:
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
        GGML_ABORT("fatal error");

vulkan-shaders/abs.comp

@@ -0,0 +1,21 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    data_d[i] = D_TYPE(abs(float(data_a[i])));
}

vulkan-shaders/neg.comp

@@ -0,0 +1,20 @@
#version 450

#include "generic_head.glsl"
#include "types.glsl"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
    if (i >= p.KX) {
        return;
    }

    data_d[i] = D_TYPE(-float(data_a[i]));
}
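
A note on the index math the two new shaders share: each workgroup is 512 invocations wide (local_size_x = 512), and 262144 = 512 * 512, so the flat element index is reconstructed from a dispatch that may have been split across the y and z dimensions; invocations past p.KX (the element count) are masked out by the early return. A host-side mirror of that expression, with a made-up helper name for illustration:

#include <cstdint>

// Mirrors the GLSL from abs.comp / neg.comp:
//   gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x
static inline uint32_t flat_index(uint32_t x, uint32_t y, uint32_t z) {
    return z * 512u * 512u + y * 512u + x; // 262144 == 512 * 512
}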

vulkan-shaders/vulkan-shaders-gen.cpp

@@ -827,6 +827,8 @@ void process_shaders() {
    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -835,6 +837,8 @@ void process_shaders() {
    string_to_spv("hardsigmoid_f32", "hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

    for (auto rte : {false, true}) {
        std::string suffix = rte ? "_rte" : "";
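
Taken together, the hunks trace the backend's existing recipe for adding a unary op; a summary distilled from this diff, not an official checklist:

// 1. vk_device_struct: declare vk_pipeline pipeline_<op>[2]
//    ([0] = f32 variant, [1] = f16 variant, indexed by dst->type == GGML_TYPE_F16).
// 2. ggml_vk_load_shaders: CREATE_UNARY(<op>) to build both pipelines.
// 3. Add GGML_UNARY_OP_<OP> to the switches in ggml_vk_op_get_pipeline,
//    ggml_vk_build_graph, ggml_vk_compute_forward, and
//    ggml_backend_vk_device_supports_op.
// 4. ggml_vk_check_results_0: clone the node with the CPU op (ggml_abs / ggml_neg)
//    so the debug checker can compare GPU output against a reference.
// 5. Add <op>.comp and register <op>_f16 / <op>_f32 via string_to_spv()
//    in process_shaders().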