metal : promote mul_mv/mul_mm batch divisors to function constants (llama/22711)

* metal : promote mul_mv/mul_mm batch divisors to function constants

* metal : take op directly in get_pipeline_mul_mv_ext
This commit is contained in:
guyfischman 2026-05-12 07:15:02 +02:00 committed by Georgi Gerganov
parent ea4652c427
commit 8ec91c91e1
4 changed files with 127 additions and 88 deletions

View File

@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) {
char base[256];
char name[256];
const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;
const int ne12 = op->src[1]->ne[2];
const int r2 = ne12 / op->src[0]->ne[2];
const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3];
GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX);
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0)
: (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0);
GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX);
const int16_t ne12 = (int16_t) op->src[1]->ne[2];
const int16_t ne13 = (int16_t) op->src[1]->ne[3];
const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]);
const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]);
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d",
base, bc_inp, bc_out, ne12, ne13, r2, r3);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2);
ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3);
ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4);
ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
}
};
GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX);
const int16_t r2 = (int16_t) (ne12 / ne02);
const int16_t r3 = (int16_t) (ne13 / ne03);
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s_nsg=%d", base, nsg);
snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

View File

@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);

View File

@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11");
};
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg);
ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,

View File

@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write(
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]];
constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]];
constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]];
template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl(
// pointers to src0 rows
device const block_q_type * ax[NR0];
FOR_UNROLL (int row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl(
device const block_q1_0 * ax[nr0];
for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
}
@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl(
// pointers to src0 rows
device const block_q8_0 * ax[NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl(
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
const int i12 = i1m%args.ne12;
const int i13 = i1m/args.ne12;
const int i12 = i1m%FC_mul_mv_ne12;
const int i13 = i1m/FC_mul_mv_ne12;
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
const int i12 = i1m%args.ne12;
const int i13 = i1m/args.ne12;
const int i12 = i1m%FC_mul_mv_ne12;
const int i13 = i1m/FC_mul_mv_ne12;
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const T0 * x = (device const T0 *) (src0 + offset0);
@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl(
// pointers to src0 rows
device const T0 * ax [NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax[row] = (device const T0 *) ((device char *) src0 + offset0);
}
@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl(
const int r1 = tgpig.y;
const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
//const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
//const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl(
device const T0 * ax [NR0];
device const T04 * ax4[NR0];
FOR_UNROLL (short row = 0; row < NR0; ++row) {
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
ax [row] = (device const T0 *) ((device char *) src0 + offset0);
ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl(
return;
}
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
device const T0 * x = (device const T0 *) (src0 + offset0);
@ -7479,10 +7482,10 @@ void kernel_mul_mv_q2_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
@ -7584,10 +7587,10 @@ void kernel_mul_mv_q3_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
@ -7758,10 +7761,10 @@ void kernel_mul_mv_q4_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
@ -7870,10 +7873,10 @@ void kernel_mul_mv_q5_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
@ -8006,10 +8009,10 @@ void kernel_mul_mv_q6_K_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
@ -8111,10 +8114,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
@ -8219,10 +8222,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
@ -8338,10 +8341,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
@ -8450,10 +8453,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
@ -8562,10 +8565,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
@ -8675,10 +8678,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
@ -8774,10 +8777,10 @@ void kernel_mul_mv_iq1_m_f32_impl(
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
@ -8883,10 +8886,10 @@ void kernel_mul_mv_iq4_nl_f32_impl(
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
@ -8992,10 +8995,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
@ -9103,10 +9106,10 @@ void kernel_mul_mv_mxfp4_f32_impl(
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint i12 = im%FC_mul_mv_ne12;
const uint i13 = im/FC_mul_mv_ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
@ -9321,6 +9324,10 @@ kernel void kernel_diag_f32(
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
@ -9347,11 +9354,11 @@ kernel void kernel_mul_mm(
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % args.ne12;
const int i13 = im / args.ne12;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
@ -9490,10 +9497,10 @@ kernel void kernel_mul_mm(
short il = il0;
const int i12 = im%args.ne12;
const int i13 = im/args.ne12;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;