Merge bdb042f754 into fc674574ca
This commit is contained in:
commit
89e247c6c7
|
|
@ -415,6 +415,9 @@ extern "C" {
|
|||
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
|
||||
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
|
||||
|
||||
// Language id stored in the provided state.
// Written by whisper_full_with_state() when the language is auto-detected; 0 (english) by default
WHISPER_API int whisper_get_lang_id_from_state(struct whisper_state * state);
|
||||
// Probability stored in the provided state for the language id above.
// Written together with the language id on auto-detection; 0.0f by default
WHISPER_API float whisper_get_lang_prob_from_state(struct whisper_state * state);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
|
||||
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
|
||||
|
|
@ -465,6 +468,9 @@ extern "C" {
|
|||
// Progress callback
|
||||
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
|
||||
|
||||
// Detected language callback
// If not NULL, called after the auto-detected language id and its probability have been stored in the state
// Use whisper_get_lang_id_from_state() / whisper_get_lang_prob_from_state() on the passed state to read them
|
||||
typedef void (*whisper_detected_language_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
|
|
@ -562,6 +568,10 @@ extern "C" {
|
|||
whisper_new_segment_callback new_segment_callback;
|
||||
void * new_segment_callback_user_data;
|
||||
|
||||
// called on detected language
|
||||
whisper_detected_language_callback detected_language_callback;
|
||||
void * detected_language_callback_user_data;
|
||||
|
||||
// called on each progress update
|
||||
whisper_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
|
|
|
|||
|
|
@ -892,6 +892,7 @@ struct whisper_state {
|
|||
std::vector<whisper_token> prompt_past1; // dynamic context from decoded output
|
||||
|
||||
int lang_id = 0; // english by default
|
||||
float lang_prob = 0.0f; // probability of the detected language
|
||||
|
||||
std::string path_model; // populated by whisper_init_from_file_with_params()
|
||||
|
||||
|
|
@ -4198,6 +4199,14 @@ float * whisper_get_logits_from_state(struct whisper_state * state) {
|
|||
return state->logits.data();
|
||||
}
|
||||
|
||||
// Returns the lang_id field of the given state.
// The field is 0 (english) by default and is overwritten when a language is auto-detected.
// state is dereferenced without a NULL check.
int whisper_get_lang_id_from_state(struct whisper_state * state) {
|
||||
return state->lang_id;
|
||||
}
|
||||
|
||||
// Returns the lang_prob field of the given state.
// The field is 0.0f by default and is set to probs[lang_id] when a language is auto-detected.
// state is dereferenced without a NULL check.
float whisper_get_lang_prob_from_state(struct whisper_state * state) {
|
||||
return state->lang_prob;
|
||||
}
|
||||
|
||||
// Token Id -> String, looked up in the vocabulary of the given context (ctx->vocab.id_to_token).
// The returned pointer is the c_str() of the string held in that table, so it is not owned by the caller.
// NOTE(review): .at() presumably throws std::out_of_range for an unknown token if id_to_token is a
// standard container - its type is not visible here, confirm before relying on it.
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
|
||||
return ctx->vocab.id_to_token.at(token).c_str();
|
||||
}
|
||||
|
|
@ -5977,6 +5986,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||
/*.new_segment_callback =*/ nullptr,
|
||||
/*.new_segment_callback_user_data =*/ nullptr,
|
||||
|
||||
/*.detected_language_callback =*/ nullptr,
|
||||
/*.detected_language_callback_user_data =*/ nullptr,
|
||||
|
||||
/*.progress_callback =*/ nullptr,
|
||||
/*.progress_callback_user_data =*/ nullptr,
|
||||
|
||||
|
|
@ -6818,9 +6830,15 @@ int whisper_full_with_state(
|
|||
return -3;
|
||||
}
|
||||
state->lang_id = lang_id;
|
||||
state->lang_prob = probs[lang_id];
|
||||
params.language = whisper_lang_str(lang_id);
|
||||
|
||||
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
||||
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[lang_id]);
|
||||
|
||||
if (params.detected_language_callback) {
|
||||
params.detected_language_callback(ctx, state, params.detected_language_callback_user_data);
|
||||
}
|
||||
|
||||
if (params.detect_language) {
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue