From d9115692308a0e006df94ff578c6f2b9c6fcd9ef Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Jun 2026 20:36:05 +0300 Subject: [PATCH] metal : implement rope_back operator (llama/24725) Reuse existing rope kernels with a function constant to toggle forward/backward rotation, avoiding duplicate kernel code. Assisted-by: pi:llama.cpp/Qwen3.6-27B --- ggml/src/ggml-metal/ggml-metal-device.cpp | 7 +++++-- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 4 ++++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index cbc9459c9..0e1f1de45 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1703,7 +1703,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_ } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { - assert(op->op == GGML_OP_ROPE); + assert(op->op == GGML_OP_ROPE || op->op == GGML_OP_ROPE_BACK); + + const bool is_back = op->op == GGML_OP_ROPE_BACK; char base[256]; char name[256]; @@ -1727,13 +1729,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_ snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type)); } - snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0); + snprintf(name, 256, "%s_imrope=%d_is_back=%d", base, is_imrope ? 1 : 0, is_back ? 1 : 0); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); + ggml_metal_cv_set_bool(cv, is_back, FC_ROPE + 1); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 26b0e77f2..a7cbc60eb 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RMS_NORM: return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0])); case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: return true; case GGML_OP_IM2COL: return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index b13bf2711..18656b346 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -375,6 +375,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_norm(ctx, idx); } break; case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: { n_fuse = ggml_metal_op_rope(ctx, idx); } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 97abecb47..31b5326ec 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4358,6 +4358,7 @@ template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_ #endif constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]]; +constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]]; static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); @@ -4381,6 +4382,9 @@ static void rope_yarn( } *cos_theta = cos(theta) * mscale; *sin_theta = sin(theta) * mscale; + if (FC_rope_is_back) { + *sin_theta *= -1.0f; + } } // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get