metal : implement rope_back operator (llama/24725)
Reuse existing rope kernels with a function constant to toggle forward/backward rotation, avoiding duplicate kernel code. Assisted-by: pi:llama.cpp/Qwen3.6-27B
This commit is contained in:
parent
ad19c98240
commit
d911569230
|
|
@ -1703,7 +1703,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_
|
|||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_ROPE);
|
||||
assert(op->op == GGML_OP_ROPE || op->op == GGML_OP_ROPE_BACK);
|
||||
|
||||
const bool is_back = op->op == GGML_OP_ROPE_BACK;
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
|
@ -1727,13 +1729,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
|
|||
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
|
||||
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
||||
snprintf(name, 256, "%s_imrope=%d_is_back=%d", base, is_imrope ? 1 : 0, is_back ? 1 : 0);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
||||
ggml_metal_cv_set_bool(cv, is_back, FC_ROPE + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
|
|
|
|||
|
|
@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_RMS_NORM:
|
||||
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
|
||||
|
|
|
|||
|
|
@ -375,6 +375,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
n_fuse = ggml_metal_op_norm(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
n_fuse = ggml_metal_op_rope(ctx, idx);
|
||||
} break;
|
||||
|
|
|
|||
|
|
@ -4358,6 +4358,7 @@ template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_
|
|||
#endif
|
||||
|
||||
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
|
||||
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
|
|
@ -4381,6 +4382,9 @@ static void rope_yarn(
|
|||
}
|
||||
*cos_theta = cos(theta) * mscale;
|
||||
*sin_theta = sin(theta) * mscale;
|
||||
if (FC_rope_is_back) {
|
||||
*sin_theta *= -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
|
|
|
|||
Loading…
Reference in New Issue