This commit is contained in:
Seven Du 2026-04-20 13:57:15 +00:00 committed by GitHub
commit 89e247c6c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 1 deletions

View File

@ -415,6 +415,9 @@ extern "C" {
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
WHISPER_API int whisper_get_lang_id_from_state(struct whisper_state * state);
WHISPER_API float whisper_get_lang_prob_from_state(struct whisper_state * state);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
@ -465,6 +468,9 @@ extern "C" {
// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
// Detected language callback
typedef void (*whisper_detected_language_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
@ -562,6 +568,10 @@ extern "C" {
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called on detected language
whisper_detected_language_callback detected_language_callback;
void * detected_language_callback_user_data;
// called on each progress update
whisper_progress_callback progress_callback;
void * progress_callback_user_data;

View File

@ -892,6 +892,7 @@ struct whisper_state {
std::vector<whisper_token> prompt_past1; // dynamic context from decoded output
int lang_id = 0; // english by default
float lang_prob = 0.0f; // probability of the detected language
std::string path_model; // populated by whisper_init_from_file_with_params()
@ -4198,6 +4199,14 @@ float * whisper_get_logits_from_state(struct whisper_state * state) {
return state->logits.data();
}
int whisper_get_lang_id_from_state(struct whisper_state * state) {
return state->lang_id;
}
float whisper_get_lang_prob_from_state(struct whisper_state * state) {
return state->lang_prob;
}
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
return ctx->vocab.id_to_token.at(token).c_str();
}
@ -5977,6 +5986,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.detected_language_callback =*/ nullptr,
/*.detected_language_callback_user_data =*/ nullptr,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
@ -6818,9 +6830,15 @@ int whisper_full_with_state(
return -3;
}
state->lang_id = lang_id;
state->lang_prob = probs[lang_id];
params.language = whisper_lang_str(lang_id);
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[lang_id]);
if (params.detected_language_callback) {
params.detected_language_callback(ctx, state, params.detected_language_callback_user_data);
}
if (params.detect_language) {
return 0;
}