CUDA: fuse relu + sqr (llama/22249)
This commit is contained in:
parent
393fdffe20
commit
b6b547885c
|
|
@ -3592,6 +3592,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
|||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) {
|
||||
const ggml_tensor * unary = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * sqr = cgraph->nodes[node_idx+1];
|
||||
|
||||
if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unary->type != sqr->type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(unary->src[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
|
||||
const ggml_tensor *scale = cgraph->nodes[node_idx];
|
||||
|
|
@ -4100,6 +4124,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
|
||||
ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
|
||||
i += 2;
|
||||
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
|
||||
|
|
|
|||
|
|
@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) {
|
|||
return x * x;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_relu_sqr(float x) {
|
||||
const float r = fmaxf(x, 0.0f);
|
||||
return r * r;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_sqrt(float x) {
|
||||
return sqrtf(x);
|
||||
}
|
||||
|
|
@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary
|
|||
GGML_ABORT("Unsupported unary op for fused unary+mul");
|
||||
}
|
||||
}
|
||||
|
||||
/* fused relu + sqr */
|
||||
|
||||
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) {
|
||||
const ggml_tensor * src = relu_node->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src));
|
||||
GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src->type == sqr_node->type);
|
||||
|
||||
const int k = ggml_nelements(src);
|
||||
if (src->type == GGML_TYPE_F16) {
|
||||
unary_cuda<op_relu_sqr>((const half *)src->data, (half *)sqr_node->data, k, stream);
|
||||
} else {
|
||||
unary_cuda<op_relu_sqr>((const float *)src->data, (float *)sqr_node->data, k, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||
|
||||
void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
|
||||
|
||||
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node);
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue