Pre-scan tensor metadata from the file to determine actual data types

This commit is contained in:
lhpqaq 2026-01-19 01:01:15 +08:00
parent b5de98b430
commit e90e242ad4
2 changed files with 172 additions and 13 deletions

View File

@ -156,6 +156,12 @@ extern "C" {
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
// skip forward by offset bytes.
bool (*skip)(void * ctx, size_t offset);
// seek to absolute position in the file.
bool (*seek)(void * ctx, size_t offset);
// get current position in the file.
size_t (*tell)(void * ctx);
} whisper_model_loader;
// grammar element type

View File

@ -1684,6 +1684,71 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
// Pre-scan tensor metadata from file to determine actual types
// This allows us to allocate device memory with the correct sizes
struct tensor_meta {
ggml_type type;
int32_t ne[4];
};
std::map<std::string, tensor_meta> tensor_type_map;
size_t tensor_start_offset = 0; // file offset where tensor section begins
// If loader supports skip, seek, and tell, scan tensor metadata first (without loading data)
if (loader->skip && loader->seek && loader->tell) {
// Remember the current position (start of tensors section)
tensor_start_offset = loader->tell(loader->context);
while (true) {
int32_t n_dims;
read_safe(loader, n_dims);
// Check for EOF after reading the first field
if (loader->eof(loader->context)) {
break;
}
int32_t length;
int32_t ttype;
read_safe(loader, length);
read_safe(loader, ttype);
tensor_meta meta;
meta.type = ggml_type(ttype);
meta.ne[0] = 1;
meta.ne[1] = 1;
meta.ne[2] = 1;
meta.ne[3] = 1;
int32_t nelements = 1;
for (int i = 0; i < n_dims; ++i) {
read_safe(loader, meta.ne[i]);
nelements *= meta.ne[i];
}
std::string name;
std::vector<char> tmp(length);
loader->read(loader->context, &tmp[0], tmp.size());
name.assign(&tmp[0], tmp.size());
// Calculate tensor data size and skip it (without loading into memory)
const size_t tensor_data_size = ggml_row_size(meta.type, meta.ne[0]) * (nelements / meta.ne[0]);
if (!loader->skip(loader->context, tensor_data_size)) {
WHISPER_LOG_ERROR("%s: failed to skip tensor data for '%s'\n", __func__, name.c_str());
return false;
}
tensor_type_map[name] = meta;
}
// Seek back to the start of tensors section for the actual data loading later
if (!loader->seek(loader->context, tensor_start_offset)) {
WHISPER_LOG_ERROR("%s: failed to seek back to tensor data\n", __func__);
return false;
}
}
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
@ -1712,6 +1777,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
buft_list_t buft_list = make_buft_list(wctx.params);
auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
// Get the tensor name
std::string tensor_name = format(ASR_TENSOR_NAMES.at(system).at(type), layer);
// If we pre-scanned tensor types, update meta tensor to use the actual type from file
auto it = tensor_type_map.find(tensor_name);
if (it != tensor_type_map.end()) {
const tensor_meta & file_meta = it->second;
if (meta->type != file_meta.type) {
// Update meta tensor type to match the file
meta->type = file_meta.type;
// Update strides based on new type
meta->nb[0] = ggml_type_size(meta->type);
meta->nb[1] = meta->nb[0] * (meta->ne[0] / ggml_blck_size(meta->type));
for (int i = 2; i < GGML_MAX_DIMS; i++) {
meta->nb[i] = meta->nb[i-1] * meta->ne[i-1];
}
}
}
ggml_op op = ASR_TENSOR_INFO.at(type);
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
if (!buft) {
@ -1721,7 +1805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
ggml_context * ctx = get_ctx(buft);
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
model.tensors[tensor_name] = tensor;
return tensor;
};
@ -1892,14 +1976,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name) == model.tensors.end()) {
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
return false;
}
auto tensor = model.tensors[name.data()];
auto tensor = model.tensors[name];
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
@ -1907,7 +1991,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
__func__, name.c_str(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
@ -1915,18 +1999,26 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
const size_t file_tensor_size = ggml_row_size(ggml_type(ttype), ne[0]) * (nelements / ne[0]);
const size_t expected_tensor_size = ggml_nbytes(tensor);
// For mixed precision models, the tensor type in file may differ from the type
// the tensor was created with. We need to handle this carefully.
// If we pre-scanned types, the tensor type should already match
// Otherwise (loader doesn't support seek), we need to handle type mismatch here
if (tensor->type != ggml_type(ttype)) {
// Mixed precision: tensor created with one type, file has another
// We need to update the tensor's type to match the file
// Type mismatch - this happens when loader doesn't support seek
// or when tensor wasn't found during pre-scan
if (!tensor_type_map.empty()) {
// We pre-scanned but types still don't match - this is unexpected
WHISPER_LOG_ERROR("%s: tensor '%s' type mismatch after pre-scan: expected %s, file has %s\n",
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
return false;
}
// Loader doesn't support seek - handle type mismatch at runtime (legacy path)
WHISPER_LOG_DEBUG("%s: tensor '%s' type mismatch (expected %s, file has %s)\n",
__func__, name.data(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
// Check if the allocated buffer is large enough for the file's data
if (file_tensor_size > expected_tensor_size) {
WHISPER_LOG_ERROR("%s: tensor '%s' buffer too small: allocated %zu bytes for %s, but file needs %zu bytes for %s\n",
__func__, name.data(), expected_tensor_size, ggml_type_name(tensor->type),
__func__, name.c_str(), expected_tensor_size, ggml_type_name(tensor->type),
file_tensor_size, ggml_type_name(ggml_type(ttype)));
return false;
}
@ -1941,10 +2033,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1];
}
} else {
// Normal case: types match, verify size
// Types match, verify size
if (file_tensor_size != expected_tensor_size) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), expected_tensor_size, file_tensor_size);
__func__, name.c_str(), expected_tensor_size, file_tensor_size);
return false;
}
}
@ -3689,6 +3781,25 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char
fin->close();
};
loader.skip = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->seekg(offset, std::ios::cur);
return fin->good();
};
loader.seek = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->clear(); // clear any error flags
fin->seekg(offset, std::ios::beg);
return fin->good();
};
loader.tell = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
auto pos = fin->tellg();
return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
};
auto ctx = whisper_init_with_params_no_state(&loader, params);
if (ctx) {
@ -3732,6 +3843,29 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
loader.close = [](void * /*ctx*/) { };
loader.skip = [](void * ctx, size_t offset) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
if (buf->current_offset + offset > buf->size) {
return false;
}
buf->current_offset += offset;
return true;
};
loader.seek = [](void * ctx, size_t offset) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
if (offset > buf->size) {
return false;
}
buf->current_offset = offset;
return true;
};
loader.tell = [](void * ctx) {
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
return buf->current_offset;
};
return whisper_init_with_params_no_state(&loader, params);
}
@ -4782,6 +4916,25 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
fin->close();
};
loader.skip = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->seekg(offset, std::ios::cur);
return fin->good();
};
loader.seek = [](void * ctx, size_t offset) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->clear();
fin->seekg(offset, std::ios::beg);
return fin->good();
};
loader.tell = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
auto pos = fin->tellg();
return (pos == std::streampos(-1)) ? SIZE_MAX : static_cast<size_t>(pos);
};
auto ctx = whisper_vad_init_with_params(&loader, params);
if (!ctx) {
whisper_vad_free(ctx);