From 7fe6b8e171d23fe12847dbf42309d46144ea6407 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 15 Apr 2026 19:04:51 +0200 Subject: [PATCH] vulkan: optimize im2col (llama/21713) * vulkan: improve im2col memory write layout * cap workgroups * minimal device tuning * use vendor_id instead of subgroup size --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++- .../ggml-vulkan/vulkan-shaders/im2col.comp | 96 +++++++------------ 2 files changed, 46 insertions(+), 63 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b2a54bd8..702a249d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; uint32_t KW; uint32_t KH; - uint32_t pelements; + uint32_t OH_batch; uint32_t CHW; int32_t s0; int32_t s1; int32_t p0; int32_t p1; @@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; - elements = { OW * KW * KH, OH, batch * IC }; + const uint32_t CHW = IC * KH * KW; + // Cap X workgroups to limit concurrent IC channel reads. + // The shader loops over X to cover the full CHW dimension. + // AMD prefers a lower limit + const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u; + const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW)); + elements = { x_elements, OW, OH * batch }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; @@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 - const uint32_t pelements = OW * KW * KH; const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, - pelements, + OH * batch, IC * KH * KW, s0, s1, p0, p1, d0, d1, batch * IC }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 674f91e5..ba4c2103 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -13,7 +13,7 @@ layout (push_constant) uniform parameter uint IW; uint IH; uint OW; uint OH; uint KW; uint KH; - uint pelements; + uint OH_batch; uint CHW; int s0; int s1; int p0; int p1; @@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void im2col(const uint y, const uint z) { - const uint gidx = gl_GlobalInvocationID.x; +void im2col(const uint ow, const uint z_idx) { + const uint oh = z_idx % p.OH; + const uint batch_idx = z_idx / p.OH; - const uint oh = y; - const uint batch = z / p.IC; - const uint ic = z % p.IC; + const uint gidx = gl_LocalInvocationID.x; + const uint src_batch = batch_idx * p.batch_offset; + const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW; - const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); - const int oh_s1 = int(oh) * p.s1; - const uint ksize = p.OW * p.KH; + const uint KHKW = p.KH * p.KW; - const uint base_linear_idx = gidx * NUM_ITER; + uint wg_x = gl_WorkGroupID.x; + do { + const uint wg_offset = wg_x * 512; - uint current_kx = base_linear_idx / ksize; - const uint rem = base_linear_idx - (current_kx * ksize); - uint current_ky = rem / p.OW; - uint current_ix = rem % p.OW; + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { + const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; - A_TYPE values[NUM_ITER]; - BDA_OFFSET_T offset_dst[NUM_ITER]; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - values[idx] = A_TYPE(0); - } - - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - - const uint linear_idx = base_linear_idx + idx; - - if (linear_idx >= p.pelements) { - continue; - } - - const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; - const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - - offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; - - if ((iih < p.IH) && (iiw < p.IW)) { - values[idx] = data_a[src_base + iih * p.IW + iiw]; - } - - if (++current_ix == p.OW) { - current_ix = 0; - if (++current_ky == p.KH) { - current_ky = 0; - current_kx++; + if (chw_idx >= p.CHW) { + return; } - } - } - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + const uint ic = chw_idx / KHKW; + const uint rem = chw_idx - ic * KHKW; + const uint ky = rem / p.KW; + const uint kx = rem - ky * p.KW; - const uint linear_idx = base_linear_idx + idx; + const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - if (linear_idx >= p.pelements) { - continue; - } + A_TYPE val = A_TYPE(0); + if (iih < p.IH && iiw < p.IW) { + val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + } #if BDA - D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); - dst_addr.d = D_TYPE(values[idx]); + D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); + out_ptr.d = D_TYPE(val); #else - data_d[offset_dst[idx]] = D_TYPE(values[idx]); + data_d[dst_row + chw_idx] = D_TYPE(val); #endif - } + } + + wg_x += gl_NumWorkGroups.x; + } while (wg_x * 512 < p.CHW); } void main() { - uint y = gl_GlobalInvocationID.y; - while (y < p.OH) { + uint ow = gl_GlobalInvocationID.y; + while (ow < p.OW) { uint z = gl_GlobalInvocationID.z; - while (z < p.batch_IC) { - im2col(y, z); + while (z < p.OH_batch) { + im2col(ow, z); z += gl_NumWorkGroups.z; } - y += gl_NumWorkGroups.y; + ow += gl_NumWorkGroups.y; } }