whisper: validate get_rows support for cpu extra buffer (#3323)

This commit is contained in:
Charles Xu 2025-07-14 14:13:44 +02:00 committed by GitHub
parent a16da91365
commit 032697b9a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 11 additions and 4 deletions

View File

@@ -1438,7 +1438,8 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
op_supported = true;
} else {
switch (op) {
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
case GGML_OP_GET_ROWS:
case GGML_OP_MUL_MAT: {
ggml_init_params params = {
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
@@ -1454,9 +1455,15 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
ggml_tensor * op_tensor = nullptr;
int64_t n_ctx = hparams.n_audio_ctx;
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
op_tensor = ggml_mul_mat(ctx, w, b);
if (op == GGML_OP_MUL_MAT) {
int64_t n_ctx = hparams.n_audio_ctx;
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
op_tensor = ggml_mul_mat(ctx, w, b);
} else if (op == GGML_OP_GET_ROWS) {
int64_t num_indices = 8;
ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
op_tensor = ggml_get_rows(ctx, w, indices);
}
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
GGML_ASSERT(w->buffer == nullptr);