From 94f327a67b8f9018c03bc5034c18310d2772b7fd Mon Sep 17 00:00:00 2001
From: Kitaiti Makoto
Date: Tue, 19 May 2026 23:03:55 +0900
Subject: [PATCH] Implement some hooks of Parakeet::Params

Connect the callback containers held by Parakeet::Params to the native
parakeet callbacks so that Ruby code is called during transcription.

* Add ruby_whisper_parakeet_new_segment_callback and
  ruby_whisper_parakeet_new_token_callback. Each returns immediately
  when its container holds neither a callback nor hook blocks;
  otherwise it enters Ruby through rb_thread_call_with_gvl, calls the
  single callback with (context, nil, payload, user_data) and then
  calls every block registered through on_new_segment / on_new_token
  with the segment or token object.
* Add progress and encoder_begin callbacks whose bodies are still
  empty (encoder_begin always returns true), and Params#abort_on,
  which currently only returns nil.
* ruby_whisper_parakeet_prepare_transcription takes the context VALUE
  pointer and registers the native callback and its user data for
  every container that is present. The abort container, when present,
  is handed over through the abort callback user data.
* DEF_HOOK tests the hook array with NIL_P and uses the
  <name>_callback_container members; on_* hooks are generated only for
  the four non-abort callbacks.
* Split token construction into
  ruby_whisper_parakeet_token_s_from_token_data and
  ruby_whisper_parakeet_token_s_from_index; callers of the former
  *_token_s_init functions are renamed accordingly.
* Export ruby_whisper_callback_container_is_present so that the
  Parakeet params code can use it.

---
Review notes (ignored by git am):

* The line structure of this patch was restored from a copy in which
  line breaks and indentation had been collapsed. Indentation (two
  spaces assumed) and the position of blank context lines are
  reconstructed; hunk line counts were used as a cross-check, but one
  blank context line in the ruby_whisper_parakeet_token.c hunk and the
  blank lines around "extern VALUE cSegment;" could not be located with
  certainty. TODO: confirm against the tree or regenerate the patch.
* Changed relative to the submitted patch, without altering any line
  count: the extern prototype of
  ruby_whisper_parakeet_token_s_from_token_data in
  ruby_whisper_parakeet_params.c now takes
  "const parakeet_token_data *" like its definition, and the loop
  indices compared against the long n_callbacks are long. The blob ids
  on the "index" line of that file are those of the submitted patch.
* NOTE(review): ruby_whisper_segment.c now references
  ruby_whisper_token_s_from_index, but no hunk here renames its
  definition. Confirm that ruby_whisper_token.c already provides it.

 bindings/ruby/ext/ruby_whisper.h              |   1 +
 .../ruby/ext/ruby_whisper_parakeet_params.c   | 176 ++++++++++++++++--
 .../ruby/ext/ruby_whisper_parakeet_segment.c  |   4 +-
 .../ruby/ext/ruby_whisper_parakeet_token.c    |  14 +-
 .../ext/ruby_whisper_parakeet_transcribe.cpp  |   5 +-
 bindings/ruby/ext/ruby_whisper_params.c       |   2 +-
 bindings/ruby/ext/ruby_whisper_segment.c      |   4 +-
 7 files changed, 182 insertions(+), 24 deletions(-)

diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h
index ee60e3747..c4839346e 100644
--- a/bindings/ruby/ext/ruby_whisper.h
+++ b/bindings/ruby/ext/ruby_whisper.h
@@ -36,6 +36,7 @@ typedef struct {
 
 typedef struct ruby_whisper_parakeet_abort_callback_user_data {
   volatile rb_atomic_t is_interrupted;
+  ruby_whisper_callback_container *callback_container;
 } ruby_whisper_parakeet_abort_callback_user_data;
 
 typedef struct ruby_whisper_log {
diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c
index 2d5f93079..740df507e 100644
--- a/bindings/ruby/ext/ruby_whisper_parakeet_params.c
+++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c
@@ -10,11 +10,18 @@
   ITERATOR(left_context_ms, INT) \
   ITERATOR(right_context_ms, INT)
 
+#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \
+  ITERATOR(new_segment, DATA) \
+  ITERATOR(new_token, DATA) \
+  ITERATOR(progress, DATA) \
+  ITERATOR(encoder_begin, DATA)
+
+#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback)
+#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
+  ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR)
+
 #define ITERATE_CALLBACK_PARAMS(ITERATOR) \
-  ITERATOR(new_segment_callback) \
-  ITERATOR(new_token_callback) \
-  ITERATOR(progress_callback) \
-  ITERATOR(encoder_begin_callback) \
+  ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
   ITERATOR(abort_callback)
 
 enum {
@@ -34,14 +41,133 @@ enum {
 #define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse)
 
 extern VALUE cParakeetParams;
+extern ID id_call;
 
 extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
 extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
+extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
+extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
+extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
 
 static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
 typedef VALUE (*param_writer_t)(VALUE, VALUE);
 static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
 
+typedef struct {
+  const ruby_whisper_callback_container *container;
+  struct parakeet_state *state;
+  int n_new;
+} call_parakeet_new_segment_callbacks_args;
+
+static void*
+call_parakeet_new_segment_callbacks(void *v_args)
+{
+  call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args;
+  const ruby_whisper_callback_container *container = args->container;
+
+  if (!NIL_P(container->callback)) {
+    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data);
+  }
+  if (NIL_P(container->callbacks)) {
+    return NULL;
+  }
+  const long n_callbacks = RARRAY_LEN(container->callbacks);
+  if (n_callbacks == 0) {
+    return NULL;
+  }
+  const int n_segments = parakeet_full_n_segments_from_state(args->state);
+  for (int i = args->n_new; i > 0; i--) {
+    int i_segment = n_segments - i; // the n_new last segments, in ascending index order
+    VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment);
+    for (long j = 0; j < n_callbacks; j++) {
+      VALUE cb = rb_ary_entry(container->callbacks, j);
+      rb_funcall(cb, id_call, 1, segment);
+    }
+  }
+
+  return NULL;
+}
+
+static void
+ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data)
+{
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  if (!ruby_whisper_callback_container_is_present(container)) {
+    return;
+  }
+
+  call_parakeet_new_segment_callbacks_args args = {
+    container,
+    state,
+    n_new,
+  };
+  rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args);
+}
+
+typedef struct {
+  const ruby_whisper_callback_container *container;
+  struct parakeet_context *context;
+  struct parakeet_state *state;
+  const parakeet_token_data *token_data;
+} call_parakeet_new_token_callbacks_args;
+
+static void*
+call_parakeet_new_token_callbacks(void *v_args)
+{
+  call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args;
+  VALUE token = Qnil;
+  const ruby_whisper_callback_container *container = args->container;
+
+  if (!NIL_P(container->callback)) {
+    token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
+    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data);
+  }
+  if (NIL_P(container->callbacks)) {
+    return NULL;
+  }
+  const long n_callbacks = RARRAY_LEN(container->callbacks);
+  if (n_callbacks == 0) {
+    return NULL;
+  }
+  if (NIL_P(token)) { // no single callback was set, so the token has not been built yet
+    token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
+  }
+  for (long i = 0; i < n_callbacks; i++) {
+    VALUE cb = rb_ary_entry(container->callbacks, i);
+    rb_funcall(cb, id_call, 1, token);
+  }
+
+  return NULL;
+}
+
+static void
+ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data)
+{
+  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
+  if (!ruby_whisper_callback_container_is_present(container)) {
+    return;
+  }
+
+  call_parakeet_new_token_callbacks_args args = {
+    container,
+    context,
+    state,
+    token_data,
+  };
+  rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
+}
+
+static void
+ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
+{
+}
+
+static bool
+ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
+{
+  return true;
+}
+
 static bool
 ruby_whisper_parakeet_abort_callback(void *user_data)
 {
@@ -52,9 +178,25 @@ ruby_whisper_parakeet_abort_callback(void *user_data)
   return is_interrupted == 1;
 }
 
+#define CALLBACK_CONTAINER_NAME(name) name ## _container
+
 void
-ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data)
+ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data)
 {
+#define PARAM_NAME(name) name
+#define USER_DATA_NAME(name) name##_user_data
+#define REGISTER_CALLBACK(name) \
+  if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \
+    rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \
+    rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \
+    rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \
+  }
+
+  ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK)
+
+  if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) {
+    abort_callback_user_data->callback_container = rwpp->abort_callback_container;
+  }
   rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback;
   rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data;
 }
@@ -119,8 +261,6 @@ const rb_data_type_t ruby_whisper_parakeet_params_type = {
     return val; \
   }
 
-#define CALLBACK_CONTAINER_NAME(name) name ## _container
-
 #define DEF_CALLBACK_PARAM_ATTR(name) \
   static VALUE \
   ruby_whisper_parakeet_params_get_##name(VALUE self) \
@@ -155,24 +295,30 @@ const rb_data_type_t ruby_whisper_parakeet_params_type = {
     return val; \
   }
 
-#define DEF_HOOK(name) \
+#define DEF_HOOK(name, data) \
   static VALUE \
   ruby_whisper_parakeet_params_on_##name(VALUE self) \
   { \
     ruby_whisper_parakeet_params *rwpp; \
     GetParakeetParams(self, rwpp); \
     const VALUE blk = rb_block_proc(); \
-    if (!rwpp->name##_container->callbacks) { \
-      rwpp->name##_container->callbacks = rb_ary_new(); \
+    if (NIL_P(rwpp->name##_callback_container->callbacks)) { \
+      rwpp->name##_callback_container->callbacks = rb_ary_new(); \
     } \
-    rb_ary_push(rwpp->name##_container->callbacks, blk); \
+    rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \
     return Qnil; \
   }
 
 ITERATE_PARAMS(DEF_PARAM_ATTR)
 ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR)
 ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR)
-ITERATE_CALLBACK_PARAMS(DEF_HOOK)
+ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
+
+static VALUE
+ruby_whisper_parakeet_params_abort_on(VALUE self)
+{
+  return Qnil;
+}
 
 static VALUE
 ruby_whisper_parakeet_params_s_allocate(VALUE klass)
@@ -240,8 +386,10 @@ init_ruby_whisper_parakeet_params(VALUE *mParakeet)
   ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR)
   ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR)
 
-#define REGISTER_HOOK(name) \
+#define REGISTER_HOOK(name, data) \
   rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0);
 
-  ITERATE_CALLBACK_PARAMS(REGISTER_HOOK)
+  ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _)
+
+  rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0);
 }
diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c
index 4479b6cda..b1e81ba93 100644
--- a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c
+++ b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c
@@ -33,7 +33,7 @@ extern VALUE sym_start_time;
 extern VALUE sym_end_time;
 extern VALUE sym_text;
 extern const rb_data_type_t ruby_whisper_parakeet_context_type;
-extern VALUE ruby_whisper_parakeet_token_s_init(struct parakeet_context *context, int i_segment, int i_token);
+extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token);
 
 static void
 rb_whisper_parakeet_segment_mark(void *p)
@@ -96,7 +96,7 @@ ruby_whisper_parakeet_segment_each_token(VALUE self)
 
   const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index);
   for (int i = 0; i < n_tokens; i++) {
-    rb_yield(ruby_whisper_parakeet_token_s_init(rwpc->context, rwps->index, i));
+    rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i));
   }
 
   return self;
diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_token.c b/bindings/ruby/ext/ruby_whisper_parakeet_token.c
index 98c30de4d..e1da3e413 100644
--- a/bindings/ruby/ext/ruby_whisper_parakeet_token.c
+++ b/bindings/ruby/ext/ruby_whisper_parakeet_token.c
@@ -108,19 +108,27 @@ ruby_whisper_parakeet_token_s_allocate(VALUE klass)
 }
 
 VALUE
-ruby_whisper_parakeet_token_s_init(struct parakeet_context *context, int i_segment, int i_token)
+ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
 {
   const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
   ruby_whisper_parakeet_token *rwpt;
   TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
 
-  *rwpt->token_data = parakeet_full_get_token_data(context, i_segment, i_token);
-  rwpt->text = rb_utf8_str_new_cstr(parakeet_full_get_token_text(context, i_segment, i_token));
+  *rwpt->token_data = *token_data;
+  rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id));
 
   return token;
 }
 
+VALUE
+ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
+{
+  parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token);
+  return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data);
+}
+
 ITERATE_MEMBERS(DEF_MEMBER_ATTR)
+// Define #text using parakeet_token_to_str or parakeet_token_to_text
 ITERATE_ATTRS(DEF_ATTR)
 
 static VALUE
diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp
index 89244c77c..114395473 100644
--- a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp
+++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp
@@ -10,7 +10,7 @@ extern "C" {
 
 extern const rb_data_type_t ruby_whisper_parakeet_context_type;
 extern const rb_data_type_t ruby_whisper_parakeet_params_type;
-extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data);
+extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data);
 
 extern ID id_to_s;
 extern ID id_to_path;
@@ -70,8 +70,9 @@ ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params)
 
   ruby_whisper_parakeet_abort_callback_user_data abort_callback_user_data = {
     0,
+    NULL,
   };
-  ruby_whisper_parakeet_prepare_transcription(rwpp, &abort_callback_user_data);
+  ruby_whisper_parakeet_prepare_transcription(rwpp, &self, &abort_callback_user_data);
 
   struct transcribe_without_gvl_args args = {
     rwpc->context,
diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c
index 7ec0ac107..7447cf5af 100644
--- a/bindings/ruby/ext/ruby_whisper_params.c
+++ b/bindings/ruby/ext/ruby_whisper_params.c
@@ -119,7 +119,7 @@ rb_whisper_abort_callback_container_allocate() {
   return container;
 }
 
-static bool
+bool
 ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
   return !NIL_P(container->callback) || !NIL_P(container->callbacks);
 }
diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c
index cf0372797..50974a6f3 100644
--- a/bindings/ruby/ext/ruby_whisper_segment.c
+++ b/bindings/ruby/ext/ruby_whisper_segment.c
@@ -15,7 +15,7 @@ extern const rb_data_type_t ruby_whisper_type;
 
 extern VALUE cSegment;
 
-extern VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int index);
+extern VALUE ruby_whisper_token_s_from_index(struct whisper_context *context, int i_segment, int index);
 
 static void
 rb_whisper_segment_mark(void *p)
@@ -190,7 +190,7 @@ ruby_whisper_segment_each_token(VALUE self)
 
   const int n_tokens = whisper_full_n_tokens(rw->context, rws->index);
   for (int i = 0; i < n_tokens; ++i) {
-    rb_yield(ruby_whisper_token_s_init(rw->context, rws->index, i));
+    rb_yield(ruby_whisper_token_s_from_index(rw->context, rws->index, i));
   }
 
   return self;