diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c index 740df507e..0063889b5 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_params.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -47,7 +47,7 @@ extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container); extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); -extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; typedef VALUE (*param_writer_t)(VALUE, VALUE); @@ -157,15 +157,146 @@ ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struc rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args); } +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int progress; +} call_parakeet_progress_callbacks_args; + +static void* +call_parakeet_progress_callback(void *v_args) +{ + call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, INT2NUM(args->progress)); + } + + return NULL; +} + static void ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_progress_callbacks_args args = { + container, + state, + progress, + }; + rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + bool is_continued; +} call_parakeet_encoder_begin_callbacks_args; + +static void* +call_parakeet_encoder_begin_callbacks(void *v_args) +{ + call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + + return NULL; } static bool ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data) { - return true; + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_parakeet_encoder_begin_callbacks_args args = { + container, + state, + true, + }; + rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_parakeet_abort_callbacks_args; + +static void* +call_parakeet_abort_callbacks(void *v_args) +{ + call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + VALUE cb; + for (long i = 0; i < n_callbacks; i++) { + cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + + return NULL; } static bool @@ -174,8 +305,21 @@ ruby_whisper_parakeet_abort_callback(void *user_data) ruby_whisper_parakeet_abort_callback_user_data *data = (ruby_whisper_parakeet_abort_callback_user_data *)user_data; int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } - return is_interrupted == 1; + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_parakeet_abort_callbacks_args args = { + data->callback_container, + false, + }; + rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args); + + return args.is_interrupted; } #define CALLBACK_CONTAINER_NAME(name) name ## _container @@ -317,6 +461,14 @@ ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _) static VALUE ruby_whisper_parakeet_params_abort_on(VALUE self) { + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(self, rwpp); + const VALUE blk = rb_block_proc(); + if (NIL_P(rwpp->abort_callback_container->callbacks)) { + rwpp->abort_callback_container->callbacks = rb_ary_new(); + } + rb_ary_push(rwpp->abort_callback_container->callbacks, blk); + return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index 50974a6f3..cf0372797 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -15,7 +15,7 @@ extern const rb_data_type_t ruby_whisper_type; extern VALUE cSegment; -extern VALUE ruby_whisper_token_s_from_index(struct whisper_context *context, int i_segment, int index); +extern VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int index); static void rb_whisper_segment_mark(void *p) @@ -190,7 +190,7 @@ ruby_whisper_segment_each_token(VALUE self) const int n_tokens = whisper_full_n_tokens(rw->context, rws->index); for (int i = 0; i < n_tokens; ++i) { - rb_yield(ruby_whisper_token_s_from_index(rw->context, rws->index, i)); + rb_yield(ruby_whisper_token_s_init(rw->context, rws->index, i)); } return self;