haopeng 2026-04-20 09:12:36 +00:00 committed by GitHub
commit 7907d72763
6 changed files with 627 additions and 26 deletions


@ -2,6 +2,8 @@
#include <regex>
#include <map>
#include <algorithm>
#include <cctype>
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
@ -16,6 +18,21 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
};
static const std::map<std::string, enum ggml_type> GGML_TYPE_MAP = {
{"q4_0", GGML_TYPE_Q4_0},
{"q4_1", GGML_TYPE_Q4_1},
{"q5_0", GGML_TYPE_Q5_0},
{"q5_1", GGML_TYPE_Q5_1},
{"q8_0", GGML_TYPE_Q8_0},
{"q2_k", GGML_TYPE_Q2_K},
{"q3_k", GGML_TYPE_Q3_K},
{"q4_k", GGML_TYPE_Q4_K},
{"q5_k", GGML_TYPE_Q5_K},
{"q6_k", GGML_TYPE_Q6_K},
{"f16", GGML_TYPE_F16},
{"f32", GGML_TYPE_F32},
};
void ggml_print_ftypes(FILE * fp) {
for (auto it = GGML_FTYPE_MAP.begin(); it != GGML_FTYPE_MAP.end(); it++) {
fprintf(fp, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
@ -38,6 +55,18 @@ enum ggml_ftype ggml_parse_ftype(const char * str) {
return ftype;
}
ggml_type ggml_parse_qtype(const char * str) {
std::string str_lower(str);
std::transform(str_lower.begin(), str_lower.end(), str_lower.begin(), ::tolower);
const auto it = GGML_TYPE_MAP.find(str_lower);
if (it == GGML_TYPE_MAP.end()) {
fprintf(stderr, "%s: unknown qtype '%s'\n", __func__, str);
return GGML_TYPE_COUNT;
}
return it->second;
}
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
@ -160,10 +189,13 @@ bool ggml_common_quantize_0(
ttype = qtype;
} else {
const int bpe = (ttype == 0) ? sizeof(float) : sizeof(uint16_t);
// For non-quantized tensors, we need to correctly calculate size based on type
// Use ggml_row_size to get the correct size for the tensor's row
const size_t row_size = ggml_row_size((ggml_type) ttype, ne[0]);
const size_t data_size = row_size * (nelements / ne[0]);
data_u8.resize(nelements*bpe);
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
data_u8.resize(data_size);
finp.read(reinterpret_cast<char *>(data_u8.data()), data_size);
}
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
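To see why the switch to `ggml_row_size` above is needed, here is a small worked example (illustration only, not part of the diff), using the Q8_0 block layout of 32 int8 weights plus one fp16 scale per 34-byte block:

```cpp
#include "ggml.h"

// Illustration only: byte counts for block-quantized tensors depend on the block
// layout, so the old nelements * bpe formula (which assumed F32/F16 data) gives
// the wrong size once already-quantized tensors pass through unmodified.
static size_t q8_0_tensor_bytes_example() {
    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, 512); // 512/32 blocks * 34 bytes = 544
    return row_size * 4096;                                     // 4096 rows -> 2228224 bytes
}
```

With the old formula the same Q8_0 tensor would have been read as 512 * 4096 * 2 = 4194304 bytes, desynchronizing the input stream for every tensor that follows.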
@ -240,3 +272,246 @@ bool ggml_common_quantize_0(
return true;
}
// Extended quantization function with per-tensor quantization support
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip,
const std::vector<tensor_quant_spec> & tensor_quant_specs) {
ggml_type default_qtype = GGML_TYPE_F32;
switch (ftype) {
case GGML_FTYPE_MOSTLY_Q4_0: default_qtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: default_qtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q5_0: default_qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: default_qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: default_qtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q2_K: default_qtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: default_qtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: default_qtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: default_qtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: default_qtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_UNKNOWN:
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
case GGML_FTYPE_MOSTLY_IQ2_XXS:
case GGML_FTYPE_MOSTLY_IQ2_XS:
case GGML_FTYPE_MOSTLY_IQ2_S:
case GGML_FTYPE_MOSTLY_IQ3_XXS:
case GGML_FTYPE_MOSTLY_IQ3_S:
case GGML_FTYPE_MOSTLY_IQ1_S:
case GGML_FTYPE_MOSTLY_IQ4_NL:
case GGML_FTYPE_MOSTLY_IQ4_XS:
case GGML_FTYPE_MOSTLY_IQ1_M:
case GGML_FTYPE_MOSTLY_BF16:
case GGML_FTYPE_MOSTLY_MXFP4:
{
fprintf(stderr, "%s: unsupported model type %d (ftype=%d)\n", __func__, ftype, ftype);
return false;
}
};
if (!ggml_is_quantized(default_qtype)) {
fprintf(stderr, "%s: invalid quantization type %d (%s)\n", __func__, default_qtype, ggml_type_name(default_qtype));
return false;
}
// Pre-compile regex patterns for efficiency
struct compiled_pattern {
std::regex regex;
ggml_type quant_type;
};
std::vector<compiled_pattern> compiled_patterns;
compiled_patterns.reserve(tensor_quant_specs.size());
for (const auto & spec : tensor_quant_specs) {
try {
compiled_patterns.push_back({std::regex(spec.pattern), spec.quant_type});
} catch (const std::regex_error & e) {
fprintf(stderr, "%s: invalid regex pattern '%s': %s\n", __func__, spec.pattern.c_str(), e.what());
return false;
}
}
size_t total_size_org = 0;
size_t total_size_new = 0;
std::vector<float> work;
std::vector<uint8_t> data_u8;
std::vector<ggml_fp16_t> data_f16;
std::vector<float> data_f32;
std::unordered_map<std::string, int> quant_type_counts;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
finp.read(reinterpret_cast<char *>(&length), sizeof(length));
finp.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (finp.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[4] = { 1, 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
finp.read (&name[0], length);
printf("%64s - [%5d, %5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype));
bool quantize = false;
ggml_type qtype = default_qtype;
// check if we should quantize this tensor
for (const auto & s : to_quant) {
if (std::regex_match(name, std::regex(s))) {
quantize = true;
break;
}
}
// check if we should skip this tensor
for (const auto & s : to_skip) {
if (std::regex_match(name, std::regex(s))) {
quantize = false;
break;
}
}
// check for per-tensor quantization specification
if (quantize) {
for (const auto & cp : compiled_patterns) {
if (std::regex_match(name, cp.regex)) {
qtype = cp.quant_type;
printf("matched pattern -> %s ", ggml_type_name(qtype));
break;
}
}
}
// quantize only 2D tensors
quantize &= (n_dims == 2);
if (quantize) {
if (ttype != GGML_TYPE_F32 && ttype != GGML_TYPE_F16) {
fprintf(stderr, "%s: unsupported ttype %d (%s) for integer quantization\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
return false;
}
if (ttype == GGML_TYPE_F16) {
data_f16.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
data_f32.resize(nelements);
for (int i = 0; i < nelements; ++i) {
data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
}
} else {
data_f32.resize(nelements);
finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
}
ttype = qtype;
quant_type_counts[ggml_type_name(qtype)]++;
} else {
// For non-quantized tensors, we need to correctly calculate size based on type
// Use ggml_row_size to get the correct size for the tensor's row
const size_t row_size = ggml_row_size((ggml_type) ttype, ne[0]);
const size_t data_size = row_size * (nelements / ne[0]);
data_u8.resize(data_size);
finp.read(reinterpret_cast<char *>(data_u8.data()), data_size);
}
fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
fout.write(reinterpret_cast<char *>(&ttype), sizeof(ttype));
for (int i = 0; i < n_dims; ++i) {
fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
}
fout.write(&name[0], length);
if (quantize) {
work.resize(nelements); // for quantization
size_t cur_size = 0;
switch ((ggml_type) ttype) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
{
cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements/ne[0], ne[0], nullptr);
} break;
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_BF16:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_COUNT:
{
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
return false;
}
}
fout.write(reinterpret_cast<char *>(work.data()), cur_size);
total_size_new += cur_size;
printf("size = %8.2f MB -> %8.2f MB\n", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
} else {
printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
total_size_new += data_u8.size();
}
total_size_org += nelements * sizeof(float);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(default_qtype));
printf("%s: quantization type summary:\n", __func__);
for (const auto & kv : quant_type_counts) {
printf("%s: %s: %d tensors\n", __func__, kv.first.c_str(), kv.second);
}
return true;
}
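One behavior worth noting in the matching above: `std::regex_match` succeeds only when the pattern matches the entire tensor name, which is why the README examples wrap patterns in `.*`. A standalone sketch (illustration only):

```cpp
#include <cassert>
#include <regex>
#include <string>

int main() {
    const std::string name = "encoder.blocks.0.attn.query.weight"; // example tensor name
    assert(!std::regex_match(name, std::regex("attn")));     // partial match does not count
    assert( std::regex_match(name, std::regex(".*attn.*"))); // full-string match succeeds
    return 0;
}
```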


@ -5,9 +5,19 @@
#include <fstream>
#include <vector>
#include <string>
#include <unordered_map>
// Structure for per-tensor quantization specification
struct tensor_quant_spec {
std::string pattern; // regex pattern to match tensor names
ggml_type quant_type; // quantization type for matched tensors
};
enum ggml_ftype ggml_parse_ftype(const char * str);
// Parse a quantization type string (e.g., "q4_0", "q8_0"); returns GGML_TYPE_COUNT if the string is not recognized
ggml_type ggml_parse_qtype(const char * str);
void ggml_print_ftypes(FILE * fp = stderr);
bool ggml_common_quantize_0(
@ -16,3 +26,12 @@ bool ggml_common_quantize_0(
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip);
// Extended quantization function with per-tensor quantization support
bool ggml_common_quantize_0(
std::ifstream & finp,
std::ofstream & fout,
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip,
const std::vector<tensor_quant_spec> & tensor_quant_specs);
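For illustration, a minimal hypothetical caller of the new overload could look like the sketch below (the helper name and the Q8_0/Q4_0 split are made up; in the tool itself `whisper_model_quantize` copies the model header and vocabulary before handing the streams to this function):

```cpp
#include <fstream>
#include <vector>
// assumes the tensor_quant_spec / ggml_common_quantize_0 declarations above are in scope

static bool quantize_with_specs(const char * fname_inp, const char * fname_out) {
    std::ifstream finp(fname_inp, std::ios::binary);
    std::ofstream fout(fname_out, std::ios::binary);
    if (!finp || !fout) {
        return false;
    }

    const std::vector<tensor_quant_spec> specs = {
        { ".*attn.*\\.weight", GGML_TYPE_Q8_0 }, // attention weights at higher precision
    };

    // quantize all 2D tensors (".*"), skip none, default everything else to Q4_0
    return ggml_common_quantize_0(finp, fout, GGML_FTYPE_MOSTLY_Q4_0, { ".*" }, {}, specs);
}
```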


@ -1,3 +1,43 @@
# quantize
Tool for integer quantization of Whisper `ggml` model files
## Features
- Standard uniform quantization (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q2_K, Q3_K, Q4_K, Q5_K, Q6_K)
- **Mixed precision quantization** - quantize different layers with different quantization types (NEW!)
## Basic Usage
```bash
./quantize model-f32.bin model-quant.bin type
```
Where `type` is one of: q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k
## Mixed Precision Quantization
You can now specify different quantization types for different tensors using the `--tensor-type` option:
```bash
./quantize [--tensor-type PATTERN=TYPE ...] model-f32.bin model-quant.bin default_type
```
### Examples
**Quantize encoder with Q8_0 (higher quality) and decoder with Q4_0 (smaller size):**
```bash
./quantize \
--tensor-type 'encoder\..*\.weight'=q8_0 \
--tensor-type 'decoder\..*\.weight'=q4_0 \
model-f32.bin model-mixed.bin q4_k
```
**Keep attention layers at higher precision:**
```bash
./quantize \
--tensor-type '.*attn.*'=q8_0 \
model-f32.bin model-mixed.bin q4_0
```
For more detailed documentation and examples, see [README_MIXED_PRECISION.md](README_MIXED_PRECISION.md).


@ -37,7 +37,11 @@ struct whisper_filters {
};
// quantize a model
static bool whisper_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
static bool whisper_model_quantize(
const std::string & fname_inp,
const std::string & fname_out,
ggml_ftype ftype,
const std::vector<tensor_quant_spec> & tensor_quant_specs = {}) {
gpt_vocab vocab;
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
@ -83,7 +87,12 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
// For mixed precision quantization, use F16 as the base ftype to ensure
// all tensor buffers are large enough to hold any quantization type
const bool use_mixed_precision = !tensor_quant_specs.empty();
const int32_t ftype_for_allocation = use_mixed_precision ? GGML_FTYPE_MOSTLY_F16 : ftype;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype_for_allocation;
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
@ -99,6 +108,9 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
if (use_mixed_precision) {
fprintf(stderr, "%s: using mixed precision quantization (ftype for allocation = F16)\n", __func__);
}
fout.write((const char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((const char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
@ -165,7 +177,15 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
"decoder.positional_embedding",
};
if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) {
// Use the extended quantization function if we have per-tensor specs
bool success;
if (!tensor_quant_specs.empty()) {
success = ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip, tensor_quant_specs);
} else {
success = ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip);
}
if (!success) {
fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
return false;
}
@ -179,12 +199,67 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
int main(int argc, char ** argv) {
ggml_backend_load_all();
if (argc != 4) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
if (argc < 4) {
fprintf(stderr, "usage: %s [--tensor-type PATTERN=TYPE ...] model-f32.bin model-quant.bin type\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, " --tensor-type PATTERN=TYPE : specify quantization type for tensors matching PATTERN\n");
fprintf(stderr, " PATTERN is a regex pattern to match tensor names\n");
fprintf(stderr, " TYPE is a quantization type (e.g., q4_0, q8_0, f16)\n");
fprintf(stderr, " Example: --tensor-type 'encoder\\..*\\.weight'=q8_0 --tensor-type 'decoder\\..*\\.weight'=q4_0\n");
fprintf(stderr, "\n");
ggml_print_ftypes(stderr);
return 1;
}
// Parse optional arguments
std::vector<tensor_quant_spec> tensor_quant_specs;
int arg_idx = 1;
while (arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0) {
if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
if (arg_idx + 1 >= argc) {
fprintf(stderr, "error: --tensor-type requires an argument\n");
return 1;
}
arg_idx++;
// Parse PATTERN=TYPE
const char * spec_str = argv[arg_idx];
const char * eq = strchr(spec_str, '=');
if (eq == nullptr) {
fprintf(stderr, "error: invalid --tensor-type format '%s', expected PATTERN=TYPE\n", spec_str);
return 1;
}
std::string pattern(spec_str, eq - spec_str);
std::string type_str(eq + 1);
ggml_type qtype = ggml_parse_qtype(type_str.c_str());
if (qtype == GGML_TYPE_COUNT) {
fprintf(stderr, "error: unknown quantization type '%s'\n", type_str.c_str());
return 1;
}
tensor_quant_spec spec;
spec.pattern = pattern;
spec.quant_type = qtype;
tensor_quant_specs.push_back(spec);
printf("Added tensor quantization spec: pattern='%s' type=%s\n",
pattern.c_str(), ggml_type_name(qtype));
} else {
fprintf(stderr, "error: unknown option '%s'\n", argv[arg_idx]);
return 1;
}
arg_idx++;
}
if (argc - arg_idx < 3) {
fprintf(stderr, "error: missing required arguments\n");
fprintf(stderr, "usage: %s [--tensor-type PATTERN=TYPE ...] model-f32.bin model-quant.bin type\n", argv[0]);
return 1;
}
// needed to initialize f16 tables
{
struct ggml_init_params params = { 0, NULL, false };
@ -192,10 +267,10 @@ int main(int argc, char ** argv) {
ggml_free(ctx);
}
const std::string fname_inp = argv[1];
const std::string fname_out = argv[2];
const std::string fname_inp = argv[arg_idx];
const std::string fname_out = argv[arg_idx + 1];
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
const ggml_ftype ftype = ggml_parse_ftype(argv[arg_idx + 2]);
const int64_t t_main_start_us = ggml_time_us();
@ -205,7 +280,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();
if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
if (!whisper_model_quantize(fname_inp, fname_out, ggml_ftype(ftype), tensor_quant_specs)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}


@ -156,6 +156,12 @@ extern "C" {
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
// skip forward by offset bytes.
bool (*skip)(void * ctx, size_t offset);
// seek to absolute position in the file.
bool (*seek)(void * ctx, size_t offset);
// get current position in the file.
size_t (*tell)(void * ctx);
} whisper_model_loader;
// grammar element type
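External loaders only get the pre-scan path in `whisper_model_load` if all three new callbacks are provided; leaving them `NULL` keeps the legacy behavior. A minimal sketch for a plain C `FILE *` context (illustration only, not part of the diff):

```cpp
#include <cstdint>
#include <cstdio>

#include "whisper.h"

// Sketch: wiring the new skip/seek/tell callbacks to a FILE * opened with fopen(..., "rb").
static whisper_model_loader make_file_loader(FILE * fp) {
    whisper_model_loader loader = {};
    loader.context = fp;
    loader.read  = [](void * ctx, void * out, size_t n) -> size_t { return fread(out, 1, n, (FILE *) ctx); };
    loader.eof   = [](void * ctx) -> bool { return feof((FILE *) ctx) != 0; };
    loader.close = [](void * ctx) { fclose((FILE *) ctx); };
    loader.skip  = [](void * ctx, size_t offset) -> bool { return fseek((FILE *) ctx, (long) offset, SEEK_CUR) == 0; };
    loader.seek  = [](void * ctx, size_t offset) -> bool { return fseek((FILE *) ctx, (long) offset, SEEK_SET) == 0; };
    loader.tell  = [](void * ctx) -> size_t {
        const long pos = ftell((FILE *) ctx);
        return pos < 0 ? SIZE_MAX : (size_t) pos;
    };
    return loader;
}
```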


@ -1684,6 +1684,71 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
// Pre-scan tensor metadata from file to determine actual types
// This allows us to allocate device memory with the correct sizes
struct tensor_meta {
ggml_type type;
int32_t ne[4];
};
std::map<std::string, tensor_meta> tensor_type_map;
size_t tensor_start_offset = 0; // file offset where tensor section begins
// If loader supports skip, seek, and tell, scan tensor metadata first (without loading data)
if (loader->skip && loader->seek && loader->tell) {
// Remember the current position (start of tensors section)
tensor_start_offset = loader->tell(loader->context);
while (true) {
int32_t n_dims;
read_safe(loader, n_dims);
// Check for EOF after reading the first field
if (loader->eof(loader->context)) {
break;
}
int32_t length;
int32_t ttype;
read_safe(loader, length);
read_safe(loader, ttype);
tensor_meta meta;
meta.type = ggml_type(ttype);
meta.ne[0] = 1;
meta.ne[1] = 1;
meta.ne[2] = 1;
meta.ne[3] = 1;
int32_t nelements = 1;
for (int i = 0; i < n_dims; ++i) {
read_safe(loader, meta.ne[i]);
nelements *= meta.ne[i];
}
std::string name;
std::vector<char> tmp(length);
loader->read(loader->context, &tmp[0], tmp.size());
name.assign(&tmp[0], tmp.size());
// Calculate tensor data size and skip it (without loading into memory)
const size_t tensor_data_size = ggml_row_size(meta.type, meta.ne[0]) * (nelements / meta.ne[0]);
if (!loader->skip(loader->context, tensor_data_size)) {
WHISPER_LOG_ERROR("%s: failed to skip tensor data for '%s'\n", __func__, name.c_str());
return false;
}
tensor_type_map[name] = meta;
}
// Seek back to the start of tensors section for the actual data loading later
if (!loader->seek(loader->context, tensor_start_offset)) {
WHISPER_LOG_ERROR("%s: failed to seek back to tensor data\n", __func__);
return false;
}
}
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
@ -1712,6 +1777,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
buft_list_t buft_list = make_buft_list(wctx.params);
auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
// Get the tensor name
std::string tensor_name = format(ASR_TENSOR_NAMES.at(system).at(type), layer);
// If we pre-scanned tensor types, update meta tensor to use the actual type from file
auto it = tensor_type_map.find(tensor_name);
if (it != tensor_type_map.end()) {
const tensor_meta & file_meta = it->second;
if (meta->type != file_meta.type) {
// Update meta tensor type to match the file
meta->type = file_meta.type;
// Update strides based on new type
meta->nb[0] = ggml_type_size(meta->type);
meta->nb[1] = meta->nb[0] * (meta->ne[0] / ggml_blck_size(meta->type));
for (int i = 2; i < GGML_MAX_DIMS; i++) {
meta->nb[i] = meta->nb[i-1] * meta->ne[i-1];
}
}
}
ggml_op op = ASR_TENSOR_INFO.at(type);
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
if (!buft) {
@ -1721,7 +1805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
ggml_context * ctx = get_ctx(buft);
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
model.tensors[tensor_name] = tensor;
return tensor;
};
@ -1892,14 +1976,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name) == model.tensors.end()) {
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
return false;
}
auto tensor = model.tensors[name.data()];
auto tensor = model.tensors[name];
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
@ -1907,32 +1991,73 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
__func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
// Calculate size based on file's tensor type
const size_t file_tensor_size = ggml_row_size(ggml_type(ttype), ne[0]) * (nelements / ne[0]);
const size_t expected_tensor_size = ggml_nbytes(tensor);
// If we pre-scanned types, the tensor type should already match
// Otherwise (loader doesn't support seek), we need to handle type mismatch here
if (tensor->type != ggml_type(ttype)) {
// Type mismatch - this happens when loader doesn't support seek
// or when tensor wasn't found during pre-scan
if (!tensor_type_map.empty()) {
// We pre-scanned but types still don't match - this is unexpected
WHISPER_LOG_ERROR("%s: tensor '%s' type mismatch after pre-scan: expected %s, file has %s\n",
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
return false;
}
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
// Loader doesn't support seek - handle type mismatch at runtime (legacy path)
WHISPER_LOG_DEBUG("%s: tensor '%s' type mismatch (expected %s, file has %s)\n",
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
// Check if the allocated buffer is large enough for the file's data
if (file_tensor_size > expected_tensor_size) {
WHISPER_LOG_ERROR("%s: tensor '%s' buffer too small: allocated %zu bytes for %s, but file needs %zu bytes for %s\n",
__func__, name.c_str(), expected_tensor_size, ggml_type_name(tensor->type),
file_tensor_size, ggml_type_name(ggml_type(ttype)));
return false;
}
// Update tensor type to match the file
tensor->type = ggml_type(ttype);
// Update tensor strides (nb) based on new type
tensor->nb[0] = ggml_type_size(tensor->type);
tensor->nb[1] = tensor->nb[0] * (tensor->ne[0] / ggml_blck_size(tensor->type));
for (int i = 2; i < GGML_MAX_DIMS; i++) {
tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1];
}
} else {
// Types match, verify size
if (file_tensor_size != expected_tensor_size) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.c_str(), expected_tensor_size, file_tensor_size);
return false;
}
}
// Now read the data - use the file's size
const size_t bytes_to_read = file_tensor_size;
if (ggml_backend_buffer_is_host(tensor->buffer)) {
// for the CPU and Metal backend, we can read directly into the tensor
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
loader->read(loader->context, tensor->data, bytes_to_read);
BYTESWAP_TENSOR(tensor);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor));
read_buf.resize(bytes_to_read);
loader->read(loader->context, read_buf.data(), read_buf.size());
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
ggml_backend_tensor_set(tensor, read_buf.data(), 0, bytes_to_read);
}
total_size += ggml_nbytes(tensor);
total_size += bytes_to_read;
model.n_loaded++;
}
@ -3656,6 +3781,25 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char
fin->close();
};
loader.skip = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->seekg(offset, std::ios::cur);
return fin->good();
};
loader.seek = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->clear(); // clear any error flags
fin->seekg(offset, std::ios::beg);
return fin->good();
};
loader.tell = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
auto pos = fin->tellg();
return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
};
auto ctx = whisper_init_with_params_no_state(&loader, params);
if (ctx) {
@ -3699,6 +3843,29 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
loader.close = [](void * /*ctx*/) { };
loader.skip = [](void * ctx, size_t offset) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
if (buf->current_offset + offset > buf->size) {
return false;
}
buf->current_offset += offset;
return true;
};
loader.seek = [](void * ctx, size_t offset) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
if (offset > buf->size) {
return false;
}
buf->current_offset = offset;
return true;
};
loader.tell = [](void * ctx) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
return buf->current_offset;
};
return whisper_init_with_params_no_state(&loader, params);
}
@ -4749,6 +4916,25 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
fin->close();
};
loader.skip = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->seekg(offset, std::ios::cur);
return fin->good();
};
loader.seek = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->clear();
fin->seekg(offset, std::ios::beg);
return fin->good();
};
loader.tell = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
auto pos = fin->tellg();
return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
};
auto ctx = whisper_vad_init_with_params(&loader, params);
if (!ctx) {
whisper_vad_free(ctx);