Pre-scan tensor metadata from the file to determine actual data types
This commit is contained in:
parent
b5de98b430
commit
e90e242ad4
|
|
@ -156,6 +156,12 @@ extern "C" {
|
|||
size_t (*read)(void * ctx, void * output, size_t read_size);
|
||||
bool (*eof)(void * ctx);
|
||||
void (*close)(void * ctx);
|
||||
// skip forward by offset bytes.
|
||||
bool (*skip)(void * ctx, size_t offset);
|
||||
// seek to absolute position in the file.
|
||||
bool (*seek)(void * ctx, size_t offset);
|
||||
// get current position in the file.
|
||||
size_t (*tell)(void * ctx);
|
||||
} whisper_model_loader;
|
||||
|
||||
// grammar element type
|
||||
|
|
|
|||
179
src/whisper.cpp
179
src/whisper.cpp
|
|
@ -1684,6 +1684,71 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
|
||||
const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
|
||||
|
||||
// Pre-scan tensor metadata from file to determine actual types
|
||||
// This allows us to allocate device memory with the correct sizes
|
||||
struct tensor_meta {
|
||||
ggml_type type;
|
||||
int32_t ne[4];
|
||||
};
|
||||
std::map<std::string, tensor_meta> tensor_type_map;
|
||||
size_t tensor_start_offset = 0; // file offset where tensor section begins
|
||||
|
||||
// If loader supports skip, seek, and tell, scan tensor metadata first (without loading data)
|
||||
if (loader->skip && loader->seek && loader->tell) {
|
||||
// Remember the current position (start of tensors section)
|
||||
tensor_start_offset = loader->tell(loader->context);
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
|
||||
read_safe(loader, n_dims);
|
||||
|
||||
// Check for EOF after reading the first field
|
||||
if (loader->eof(loader->context)) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t length;
|
||||
int32_t ttype;
|
||||
|
||||
read_safe(loader, length);
|
||||
read_safe(loader, ttype);
|
||||
|
||||
tensor_meta meta;
|
||||
meta.type = ggml_type(ttype);
|
||||
meta.ne[0] = 1;
|
||||
meta.ne[1] = 1;
|
||||
meta.ne[2] = 1;
|
||||
meta.ne[3] = 1;
|
||||
|
||||
int32_t nelements = 1;
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
read_safe(loader, meta.ne[i]);
|
||||
nelements *= meta.ne[i];
|
||||
}
|
||||
|
||||
std::string name;
|
||||
std::vector<char> tmp(length);
|
||||
loader->read(loader->context, &tmp[0], tmp.size());
|
||||
name.assign(&tmp[0], tmp.size());
|
||||
|
||||
// Calculate tensor data size and skip it (without loading into memory)
|
||||
const size_t tensor_data_size = ggml_row_size(meta.type, meta.ne[0]) * (nelements / meta.ne[0]);
|
||||
if (!loader->skip(loader->context, tensor_data_size)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to skip tensor data for '%s'\n", __func__, name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_type_map[name] = meta;
|
||||
}
|
||||
|
||||
// Seek back to the start of tensors section for the actual data loading later
|
||||
if (!loader->seek(loader->context, tensor_start_offset)) {
|
||||
WHISPER_LOG_ERROR("%s: failed to seek back to tensor data\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
|
|
@ -1712,6 +1777,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
buft_list_t buft_list = make_buft_list(wctx.params);
|
||||
|
||||
auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
|
||||
// Get the tensor name
|
||||
std::string tensor_name = format(ASR_TENSOR_NAMES.at(system).at(type), layer);
|
||||
|
||||
// If we pre-scanned tensor types, update meta tensor to use the actual type from file
|
||||
auto it = tensor_type_map.find(tensor_name);
|
||||
if (it != tensor_type_map.end()) {
|
||||
const tensor_meta & file_meta = it->second;
|
||||
if (meta->type != file_meta.type) {
|
||||
// Update meta tensor type to match the file
|
||||
meta->type = file_meta.type;
|
||||
// Update strides based on new type
|
||||
meta->nb[0] = ggml_type_size(meta->type);
|
||||
meta->nb[1] = meta->nb[0] * (meta->ne[0] / ggml_blck_size(meta->type));
|
||||
for (int i = 2; i < GGML_MAX_DIMS; i++) {
|
||||
meta->nb[i] = meta->nb[i-1] * meta->ne[i-1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_op op = ASR_TENSOR_INFO.at(type);
|
||||
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
|
||||
if (!buft) {
|
||||
|
|
@ -1721,7 +1805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
ggml_context * ctx = get_ctx(buft);
|
||||
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
|
||||
|
||||
model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
|
||||
model.tensors[tensor_name] = tensor;
|
||||
|
||||
return tensor;
|
||||
};
|
||||
|
|
@ -1892,14 +1976,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
name.assign(&tmp[0], tmp.size());
|
||||
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
auto tensor = model.tensors[name];
|
||||
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
|
||||
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
||||
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
|
||||
return false;
|
||||
|
|
@ -1907,7 +1991,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
|
||||
// The shape stored in the file (ne) must match the shape of the tensor created for this name.
// "got" reports the dimensions found in the file, "expected" those of the created tensor.
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
            __func__, name.c_str(), ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
    return false;
}
|
||||
|
||||
|
|
@ -1915,18 +1999,26 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
const size_t file_tensor_size = ggml_row_size(ggml_type(ttype), ne[0]) * (nelements / ne[0]);
|
||||
const size_t expected_tensor_size = ggml_nbytes(tensor);
|
||||
|
||||
// For mixed precision models, the tensor type in file may differ from the type
|
||||
// the tensor was created with. We need to handle this carefully.
|
||||
// If we pre-scanned types, the tensor type should already match
|
||||
// Otherwise (loader doesn't support seek), we need to handle type mismatch here
|
||||
if (tensor->type != ggml_type(ttype)) {
|
||||
// Mixed precision: tensor created with one type, file has another
|
||||
// We need to update the tensor's type to match the file
|
||||
// Type mismatch - this happens when loader doesn't support seek
|
||||
// or when tensor wasn't found during pre-scan
|
||||
if (!tensor_type_map.empty()) {
|
||||
// We pre-scanned but types still don't match - this is unexpected
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' type mismatch after pre-scan: expected %s, file has %s\n",
|
||||
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Loader doesn't support seek - handle type mismatch at runtime (legacy path)
|
||||
WHISPER_LOG_DEBUG("%s: tensor '%s' type mismatch (expected %s, file has %s)\n",
|
||||
__func__, name.data(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
|
||||
__func__, name.c_str(), ggml_type_name(tensor->type), ggml_type_name(ggml_type(ttype)));
|
||||
|
||||
// Check if the allocated buffer is large enough for the file's data
|
||||
if (file_tensor_size > expected_tensor_size) {
|
||||
WHISPER_LOG_ERROR("%s: tensor '%s' buffer too small: allocated %zu bytes for %s, but file needs %zu bytes for %s\n",
|
||||
__func__, name.data(), expected_tensor_size, ggml_type_name(tensor->type),
|
||||
__func__, name.c_str(), expected_tensor_size, ggml_type_name(tensor->type),
|
||||
file_tensor_size, ggml_type_name(ggml_type(ttype)));
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1941,10 +2033,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||
tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1];
|
||||
}
|
||||
} else {
|
||||
// Normal case: types match, verify size
|
||||
// Types match, verify size
|
||||
// file_tensor_size is derived from the type and shape in the file header, expected_tensor_size
// from the created tensor; "got" therefore reports the file value and "expected" the tensor value.
if (file_tensor_size != expected_tensor_size) {
    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
            __func__, name.c_str(), file_tensor_size, expected_tensor_size);
    return false;
}
|
||||
}
|
||||
|
|
@ -3689,6 +3781,25 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char
|
|||
fin->close();
|
||||
};
|
||||
|
||||
// Skip forward by `offset` bytes from the current read position of the std::ifstream.
// Returns false if the offset cannot be represented or the stream is not good afterwards.
loader.skip = [](void * ctx, size_t offset) -> bool {
    std::ifstream * fin = static_cast<std::ifstream *>(ctx);

    // seekg takes a signed offset: a size_t beyond its range would turn negative and
    // move the read position backwards instead of forwards
    const std::streamoff delta = static_cast<std::streamoff>(offset);
    if (delta < 0) {
        return false;
    }

    fin->seekg(delta, std::ios::cur);
    return fin->good();
};

// Seek to the absolute position `offset` (counted from the beginning of the file).
// Returns false if the offset cannot be represented or the stream is not good afterwards.
loader.seek = [](void * ctx, size_t offset) -> bool {
    std::ifstream * fin = static_cast<std::ifstream *>(ctx);
    fin->clear(); // clear any error flags so an earlier failed read does not block the seek

    const std::streamoff target = static_cast<std::streamoff>(offset);
    if (target < 0) {
        return false;
    }

    fin->seekg(target, std::ios::beg);
    return fin->good();
};

// Report the current read position, or SIZE_MAX when tellg() fails.
loader.tell = [](void * ctx) -> size_t {
    std::ifstream * fin = static_cast<std::ifstream *>(ctx);
    const std::streampos pos = fin->tellg();
    if (pos == std::streampos(-1)) {
        return SIZE_MAX;
    }
    return static_cast<size_t>(pos);
};
|
||||
|
||||
auto ctx = whisper_init_with_params_no_state(&loader, params);
|
||||
|
||||
if (ctx) {
|
||||
|
|
@ -3732,6 +3843,29 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
|
|||
|
||||
loader.close = [](void * /*ctx*/) { };
|
||||
|
||||
// Skip forward by `offset` bytes inside the in-memory buffer.
// Returns false (leaving the position unchanged) if that would move past the end of the buffer.
loader.skip = [](void * ctx, size_t offset) -> bool {
    buf_context * buf = reinterpret_cast<buf_context *>(ctx);

    // Compare against the bytes that remain instead of computing current_offset + offset:
    // the sum can wrap around for very large offsets and would then pass the bounds check.
    if (buf->current_offset > buf->size || offset > buf->size - buf->current_offset) {
        return false;
    }

    buf->current_offset += offset;
    return true;
};

// Move to the absolute position `offset`; positions up to and including buf->size are accepted.
loader.seek = [](void * ctx, size_t offset) -> bool {
    buf_context * buf = reinterpret_cast<buf_context *>(ctx);
    if (offset > buf->size) {
        return false;
    }
    buf->current_offset = offset;
    return true;
};

// Report the current read position inside the buffer.
loader.tell = [](void * ctx) -> size_t {
    buf_context * buf = reinterpret_cast<buf_context *>(ctx);
    return buf->current_offset;
};
|
||||
|
||||
return whisper_init_with_params_no_state(&loader, params);
|
||||
}
|
||||
|
||||
|
|
@ -4782,6 +4916,25 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
|
|||
fin->close();
|
||||
};
|
||||
|
||||
// Skip forward by `offset` bytes from the current read position of the std::ifstream.
// Returns false if the offset cannot be represented or the stream is not good afterwards.
loader.skip = [](void * ctx, size_t offset) -> bool {
    std::ifstream * fin = static_cast<std::ifstream *>(ctx);

    // seekg takes a signed offset: a size_t beyond its range would turn negative and
    // move the read position backwards instead of forwards
    const std::streamoff delta = static_cast<std::streamoff>(offset);
    if (delta < 0) {
        return false;
    }

    fin->seekg(delta, std::ios::cur);
    return fin->good();
};

// Seek to the absolute position `offset` (counted from the beginning of the file).
// Returns false if the offset cannot be represented or the stream is not good afterwards.
loader.seek = [](void * ctx, size_t offset) -> bool {
    std::ifstream * fin = static_cast<std::ifstream *>(ctx);
    fin->clear(); // clear any error flags so an earlier failed read does not block the seek

    const std::streamoff target = static_cast<std::streamoff>(offset);
    if (target < 0) {
        return false;
    }

    fin->seekg(target, std::ios::beg);
    return fin->good();
};

// Report the current read position, or SIZE_MAX when tellg() fails.
loader.tell = [](void * ctx) -> size_t {
    std::ifstream * fin = static_cast<std::ifstream *>(ctx);
    const std::streampos pos = fin->tellg();
    if (pos == std::streampos(-1)) {
        return SIZE_MAX;
    }
    return static_cast<size_t>(pos);
};
|
||||
|
||||
auto ctx = whisper_vad_init_with_params(&loader, params);
|
||||
if (!ctx) {
|
||||
whisper_vad_free(ctx);
|
||||
|
|
|
|||
Loading…
Reference in New Issue