vulkan: handle rope with large number of rows (llama/18306)

This commit is contained in:
Jeff Bolz 2025-12-26 09:53:46 -06:00 committed by Georgi Gerganov
parent 33f75a88ac
commit 67473fef57
6 changed files with 31 additions and 7 deletions

View File

@ -1192,6 +1192,7 @@ struct vk_op_diag_mask_push_constants {
struct vk_op_rope_push_constants {
uint32_t rope_mode;
uint32_t ncols;
uint32_t nrows;
uint32_t n_dims;
float freq_scale;
uint32_t p_delta_rows;
@ -9090,10 +9091,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
break;
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
uint32_t nrows = (uint32_t)ggml_nrows(src0);
uint32_t z = 1;
if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
z = CEIL_DIV(nrows, 32768);
nrows = 32768;
}
elements = { nrows, (uint32_t)ne00, z };
} break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@ -10021,7 +10032,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
vk_op_rope_push_constants rope {
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_multi(i0, i1, pc);
}

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_neox(i0, i1, pc);
}

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_norm(i0, i1, pc);
}

View File

@ -6,6 +6,7 @@
struct rope_params {
uint rope_mode;
uint ncols;
uint nrows;
uint n_dims;
float freq_scale;
uint p_delta_rows;

View File

@ -6,6 +6,9 @@
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i1 = gl_GlobalInvocationID.x;
const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
if (i1 >= pc.nrows) {
return;
}
rope_vision(i0, i1, pc);
}