vulkan: Support unaligned tensors for ROPE (llama/22637)

This commit is contained in:
Jeff Bolz 2026-05-17 04:30:16 -05:00 committed by Georgi Gerganov
parent c7dd64c606
commit e417ce7aeb
3 changed files with 25 additions and 2 deletions

View File

@ -1354,6 +1354,8 @@ struct vk_op_rope_push_constants {
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t a_offset;
uint32_t d_offset;
};
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
@ -10126,6 +10128,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src3);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(src3);
}
template<typename PC>
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
@ -11270,6 +11281,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
(uint32_t)src0->ne[2],
nb01, nb02, nb03,
nb11, nb12, nb13,
0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets
};
return rope;
@ -11365,6 +11377,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(buf[i] != nullptr);
}
// a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned.
// Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset.
pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type);
offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1);
std::array<uint32_t, 3> elements;
elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };

View File

@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03,
// Per-row offset in shared memory
const uint ix = i0;
#else
const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
#endif
return ix;
}
@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_
idst = i1*p.nb11 + i0;
idst += rope_data_i[i2].x * p.set_rows_stride;
}
idst += p.d_offset;
if (i0 >= p.n_dims) {
rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_
idst = i1*p.nb11 + i0/2;
idst += rope_data_i[i2].x * p.set_rows_stride;
}
idst += p.d_offset;
if (i0 >= p.n_dims) {
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope
idst = i1*p.nb11 + i0/2;
idst += rope_data_i[i2].x * p.set_rows_stride;
}
idst += p.d_offset;
if (i0 >= p.n_dims) {
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rop
return;
}
const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
const int sect_dims = p.sections[0] + p.sections[1];

View File

@ -26,6 +26,9 @@ struct rope_params {
uint nb11;
uint nb12;
uint nb13;
uint a_offset;
uint d_offset;
};
#endif // !defined(GGML_ROPE_PARAMS)