diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 094f1fce9..7f643209a 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -24,14 +24,6 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; -typedef struct { - VALUE *context; - VALUE user_data; - VALUE callback; - VALUE callbacks; - bool is_interrupted; -} ruby_whisper_abort_callback_container; - typedef struct ruby_whisper_abort_callback_user_data { volatile rb_atomic_t is_interrupted; ruby_whisper_callback_container *callback_container; @@ -69,7 +61,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; @@ -111,6 +103,13 @@ typedef struct parsed_samples_t { bool memview_exported; } parsed_samples_t; +typedef struct full_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} full_args; + typedef struct { VALUE *context; VALUE *params; diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 0f2246dd9..c44f210a0 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data); ID 
transcribe_option_names[1]; @@ -38,13 +38,6 @@ typedef struct fill_samples_args { int n_samples; } fill_samples_args; -typedef struct full_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; -} full_args; - typedef struct full_parallel_args { VALUE *context; VALUE *params; @@ -71,7 +64,7 @@ typedef struct full_parallel_without_gvl_args { } full_parallel_without_gvl_args; typedef struct full_ubf_args { - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_abort_callback_user_data *abort_callback_user_data; } full_ubf_args; static void @@ -480,10 +473,10 @@ full_ubf(void *rb_args) { full_ubf_args *args = (full_ubf_args *)rb_args; - args->abort_callback_container->is_interrupted = true; + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); } -static VALUE +VALUE full_body(VALUE rb_args) { full_args *args = (full_args *)rb_args; @@ -493,7 +486,11 @@ full_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, 1); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, 1, &abort_callback_user_data); struct full_without_gvl_args full_without_gvl_args = { rw->context, @@ -503,7 +500,7 @@ full_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_without_gvl_args.result); @@ -562,7 +559,11 @@ full_parallel_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, args->n_processors); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + 
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data); struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { rw->context, @@ -573,7 +574,7 @@ full_parallel_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_parallel_without_gvl_args.result); diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 7447cf5af..4468b4348 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -97,38 +97,11 @@ ruby_whisper_callback_container_allocate() { return container; } -static void -rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) -{ - if (rwc == NULL) return; - - rb_gc_mark(rwc->user_data); - rb_gc_mark(rwc->callback); - rb_gc_mark(rwc->callbacks); -} - -static ruby_whisper_abort_callback_container* -rb_whisper_abort_callback_container_allocate() { - ruby_whisper_abort_callback_container *container; - container = ALLOC(ruby_whisper_abort_callback_container); - container->context = NULL; - container->user_data = Qnil; - container->callback = Qnil; - container->callbacks = Qnil; - container->is_interrupted = false; - return container; -} - bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { return !NIL_P(container->callback) || !NIL_P(container->callbacks); } -static bool -ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { - return !NIL_P(container->callback) || !NIL_P(container->callbacks); -} - typedef struct { const ruby_whisper_callback_container *container; struct whisper_state *state; @@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct 
whisper_s } typedef struct { - const ruby_whisper_abort_callback_container *container; - struct whisper_state *state; + const ruby_whisper_callback_container *container; bool is_interrupted; } call_abort_callbacks_args; static void* call_abort_callbacks(void *v_args) { call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; - const ruby_whisper_abort_callback_container *container = args->container; - - if (container->is_interrupted) { - args->is_interrupted = true; - return NULL; - } + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) { if (NIL_P(container->callbacks)) { return NULL; } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (0 == n_callbacks) { return NULL; } - for (int j = 0; j < callbacks_len; j++) { + for (int j = 0; j < n_callbacks; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + VALUE result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) { } static bool abort_callback(void * user_data) { - const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; - if (container->is_interrupted) { + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if 
(is_interrupted) { return true; } - if (!ruby_whisper_abort_callback_container_is_present(container)) { + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { return false; } call_abort_callbacks_args args = { - container, - NULL, + data->callback_container, false }; rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); @@ -364,7 +332,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { + if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } @@ -374,7 +342,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) } } -static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; @@ -393,10 +361,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } + abort_callback_user_data->callback_container = rwp->abort_callback_container; rwp->abort_callback_container->context = context; rwp->params.abort_callback = abort_callback; - rwp->abort_callback_container->is_interrupted = false; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; + rwp->params.abort_callback_user_data = (void *)abort_callback_user_data; } static void set_vad_params(ruby_whisper_params *rwp) @@ -410,10 +378,10 @@ static void 
set_vad_params(ruby_whisper_params *rwp) TODO: Set abort callback to trap SIGINT and SIGTERM */ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { check_thread_safety(rwp, n_processors); - register_callbacks(rwp, context); + register_callbacks(rwp, context, abort_callback_user_data); set_vad_params(rwp); } @@ -424,7 +392,7 @@ rb_whisper_params_mark(void *p) ruby_whisper_callback_container_mark(rwp->new_segment_callback_container); ruby_whisper_callback_container_mark(rwp->progress_callback_container); ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); + ruby_whisper_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -495,7 +463,7 @@ ruby_whisper_params_allocate(VALUE klass) rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); + rwp->abort_callback_container = ruby_whisper_callback_container_allocate(); return obj; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 37656af1c..9bd5dfbe4 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,6 +16,7 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern VALUE full_body(VALUE rb_args); typedef struct{ struct whisper_context *context; @@ -35,18 +36,6 @@ transcribe_without_gvl(void *rb_args) return 
NULL; } -typedef struct { - ruby_whisper_abort_callback_container *abort_callback_container; -} transcribe_ubf_args; - -static void -transcribe_ubf(void *rb_args) -{ - transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; - - args->abort_callback_container->is_interrupted = true; -} - /* * transcribe a single file * can emit to a block results @@ -91,32 +80,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - // Commented out because it is work in progress - // { - // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - // bool is_aborted = *(bool*)user_data; - // return !is_aborted; - // }; - // rwp->params.encoder_begin_callback_user_data = &is_aborted; - // } - - prepare_transcription(rwp, &self, n_processors); - - transcribe_without_gvl_args args = { - rw->context, - &rwp->params, + full_args args = { + &self, + &params, pcmf32.data(), - pcmf32.size(), - n_processors, - 0, + (int)pcmf32.size(), }; - transcribe_ubf_args ubf_args = { - rwp->abort_callback_container, - }; - rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); - if (args.result != 0) { + VALUE rb_result = full_body((VALUE)&args); + const int result = NUM2INT(rb_result); + if (result != 0) { fprintf(stderr, "failed to process audio\n"); return self; }