vulkan: Reduce temporary memory usage for TOP_K (llama/17623)
- Compute row size for the temp buffer based on the output of the first pass.
- Update shader addressing math to use the output row size.
- Pass the output row size as "ncols_output"; what used to be "ncols_output" is now "k".

For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB.
This commit is contained in:
parent
fffdf679d4
commit
86cb5ab93f
|
|
@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
|
||||||
uint32_t orig_ncols;
|
uint32_t orig_ncols;
|
||||||
uint32_t ncols_input;
|
uint32_t ncols_input;
|
||||||
uint32_t ncols_output;
|
uint32_t ncols_output;
|
||||||
|
uint32_t k;
|
||||||
uint32_t nrows;
|
uint32_t nrows;
|
||||||
uint32_t first_pass;
|
uint32_t first_pass;
|
||||||
uint32_t last_pass;
|
uint32_t last_pass;
|
||||||
|
|
@ -1673,6 +1674,14 @@ class vk_perf_logger {
|
||||||
timings[name.str()].push_back(time);
|
timings[name.str()].push_back(time);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (node->op == GGML_OP_TOP_K) {
|
||||||
|
std::stringstream name;
|
||||||
|
name << ggml_op_name(node->op) <<
|
||||||
|
" K=" << node->ne[0] <<
|
||||||
|
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
|
||||||
|
timings[name.str()].push_back(time);
|
||||||
|
return;
|
||||||
|
}
|
||||||
timings[ggml_op_name(node->op)].push_back(time);
|
timings[ggml_op_name(node->op)].push_back(time);
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
|
|
@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
uint32_t nrows = ggml_nrows(src0);
|
uint32_t nrows = ggml_nrows(src0);
|
||||||
uint32_t k = dst->ne[0];
|
uint32_t k = dst->ne[0];
|
||||||
|
|
||||||
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
|
vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
|
||||||
|
|
||||||
// Reserve space for ivec2 per element, double buffered
|
|
||||||
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
|
|
||||||
const size_t x_sz = dbl_buf_size * 2;
|
|
||||||
uint32_t dbl_buf_index = 0;
|
|
||||||
|
|
||||||
if (ctx->prealloc_size_x < x_sz) {
|
|
||||||
ctx->prealloc_size_x = x_sz;
|
|
||||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
|
||||||
}
|
|
||||||
if (ctx->prealloc_x_need_sync) {
|
if (ctx->prealloc_x_need_sync) {
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
}
|
}
|
||||||
|
|
@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
// largest elements. Repeat until we have the top K elements.
|
// largest elements. Repeat until we have the top K elements.
|
||||||
// Need to do at least one iteration to write out the results.
|
// Need to do at least one iteration to write out the results.
|
||||||
bool done_one_iter = false;
|
bool done_one_iter = false;
|
||||||
|
uint32_t dbl_buf_index = 0;
|
||||||
|
size_t dbl_buf_size;
|
||||||
while (num_elements > k || !done_one_iter) {
|
while (num_elements > k || !done_one_iter) {
|
||||||
done_one_iter = true;
|
|
||||||
|
|
||||||
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
||||||
// But if K is larger, then we need a larger workgroup
|
// But if K is larger, then we need a larger workgroup
|
||||||
|
|
@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
// Number of elements remaining after this pass
|
// Number of elements remaining after this pass
|
||||||
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
||||||
|
|
||||||
|
pc2.ncols_output = num_dst_elements;
|
||||||
|
|
||||||
|
if (!done_one_iter) {
|
||||||
|
// Reserve space for ivec2 per element, double buffered
|
||||||
|
// K per workgroup per row
|
||||||
|
dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
|
||||||
|
dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
||||||
|
const size_t x_sz = dbl_buf_size * 2;
|
||||||
|
|
||||||
|
if (ctx->prealloc_size_x < x_sz) {
|
||||||
|
ctx->prealloc_size_x = x_sz;
|
||||||
|
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
vk_subbuffer src_buf;
|
vk_subbuffer src_buf;
|
||||||
vk_subbuffer dst_buf;
|
vk_subbuffer dst_buf;
|
||||||
|
|
||||||
|
|
@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
if (num_elements > k) {
|
if (num_elements > k) {
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
}
|
}
|
||||||
|
done_one_iter = true;
|
||||||
}
|
}
|
||||||
ctx->prealloc_x_need_sync = true;
|
ctx->prealloc_x_need_sync = true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ layout (push_constant) uniform parameter {
|
||||||
uint orig_ncols;
|
uint orig_ncols;
|
||||||
uint ncols_input;
|
uint ncols_input;
|
||||||
uint ncols_output;
|
uint ncols_output;
|
||||||
|
uint k;
|
||||||
uint nrows;
|
uint nrows;
|
||||||
uint first_pass;
|
uint first_pass;
|
||||||
uint last_pass;
|
uint last_pass;
|
||||||
|
|
@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) {
|
||||||
const uint row_offset = row * p.ncols_input;
|
const uint row_offset = row * p.ncols_input;
|
||||||
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||||
} else {
|
} else {
|
||||||
const uint row_offset = row * p.orig_ncols;
|
const uint row_offset = row * p.ncols_input;
|
||||||
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
|
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) {
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
if (p.ncols_output == 1) {
|
if (p.k == 1) {
|
||||||
// Fast path for single output - just do a max reduction
|
// Fast path for single output - just do a max reduction
|
||||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||||
if (col < s) {
|
if (col < s) {
|
||||||
|
|
@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
if (col < p.k) {
|
||||||
if (p.last_pass != 0) {
|
if (p.last_pass != 0) {
|
||||||
const uint row_offset = row * p.ncols_output;
|
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||||
data_d[row_offset + col] = dst_row[col].x;
|
const uint row_offset = row * p.k;
|
||||||
|
data_d[row_offset + col] = dst_row[col].x;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
|
||||||
data_t[row_offset + col] = dst_row[col];
|
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
|
||||||
|
data_t[row_offset + col] = dst_row[col];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
|
||||||
uint orig_ncols;
|
uint orig_ncols;
|
||||||
uint ncols_input;
|
uint ncols_input;
|
||||||
uint ncols_output;
|
uint ncols_output;
|
||||||
|
uint k;
|
||||||
uint nrows;
|
uint nrows;
|
||||||
uint first_pass;
|
uint first_pass;
|
||||||
uint last_pass;
|
uint last_pass;
|
||||||
|
|
@ -60,7 +61,7 @@ void topk(const uint row) {
|
||||||
const uint row_offset = row * p.ncols_input;
|
const uint row_offset = row * p.ncols_input;
|
||||||
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||||
} else {
|
} else {
|
||||||
const uint row_offset = row * p.orig_ncols;
|
const uint row_offset = row * p.ncols_input;
|
||||||
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
|
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -68,7 +69,7 @@ void topk(const uint row) {
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
if (p.ncols_output == 1) {
|
if (p.k == 1) {
|
||||||
// Fast path for single output - just do a max reduction
|
// Fast path for single output - just do a max reduction
|
||||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||||
if (tid < s) {
|
if (tid < s) {
|
||||||
|
|
@ -98,7 +99,7 @@ void topk(const uint row) {
|
||||||
uint range_max = 0xFF800000;
|
uint range_max = 0xFF800000;
|
||||||
// How many are above the current range, and how many we need to find.
|
// How many are above the current range, and how many we need to find.
|
||||||
uint total = 0;
|
uint total = 0;
|
||||||
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
|
uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
|
||||||
|
|
||||||
while (mask != 0) {
|
while (mask != 0) {
|
||||||
barrier();
|
barrier();
|
||||||
|
|
@ -139,7 +140,7 @@ void topk(const uint row) {
|
||||||
range_max = range_min + ((min_idx + 1) << shift);
|
range_max = range_min + ((min_idx + 1) << shift);
|
||||||
range_min = range_min + (min_idx << shift);
|
range_min = range_min + (min_idx << shift);
|
||||||
|
|
||||||
if (total == p.ncols_output) {
|
if (total == p.k) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
total -= counts[min_idx];
|
total -= counts[min_idx];
|
||||||
|
|
@ -179,13 +180,17 @@ void topk(const uint row) {
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
if (tid < p.k) {
|
||||||
if (p.last_pass != 0) {
|
if (p.last_pass != 0) {
|
||||||
const uint row_offset = row * p.ncols_output;
|
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||||
data_d[row_offset + tid] = dst_row[tid].x;
|
const uint row_offset = row * p.k;
|
||||||
|
data_d[row_offset + tid] = dst_row[tid].x;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
|
||||||
data_t[row_offset + tid] = dst_row[tid];
|
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
|
||||||
|
data_t[row_offset + tid] = dst_row[tid];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue