diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index e92605db..65c3d906 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -23,6 +23,8 @@ extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); +extern void ruby_whisper_lock_gvl(void); +extern void ruby_whisper_unlock_gvl(void); ID transcribe_option_names[1]; @@ -55,6 +57,15 @@ typedef struct full_without_gvl_args { int result; } full_without_gvl_args; +typedef struct full_parallel_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int n_processors; + int result; +} full_parallel_without_gvl_args; + typedef struct full_ubf_args { ruby_whisper_abort_callback_container *abort_callback_container; } full_ubf_args; @@ -443,6 +454,8 @@ release_samples(VALUE rb_parsed_args) static void* full_without_gvl(void *rb_args) { + ruby_whisper_unlock_gvl(); + full_without_gvl_args *args = (full_without_gvl_args *)rb_args; args->result = whisper_full(args->context, *args->params, args->samples, args->n_samples); return NULL; @@ -479,6 +492,7 @@ full_body(VALUE rb_args) rwp->abort_callback_container, }; rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); + ruby_whisper_lock_gvl(); return INT2NUM(full_without_gvl_args.result); } @@ -517,6 +531,16 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static void* +full_parallel_without_gvl(void *rb_args) +{ + ruby_whisper_unlock_gvl(); + + full_parallel_without_gvl_args *args = (full_parallel_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + return NULL; +} + static VALUE full_parallel_body(VALUE rb_args) { @@ -528,9 +552,21 @@ full_parallel_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, args->n_processors); - int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); - return INT2NUM(result); + struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + args->n_processors, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); + ruby_whisper_lock_gvl(); + return INT2NUM(full_parallel_without_gvl_args.result); } /* diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 31e86f67..40935076 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -127,7 +127,7 @@ ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container } static bool -abort_ruby_whisper_callback_container_is_present(ruby_whisper_abort_callback_container *container) { +ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { return !NIL_P(container->callback) || !NIL_P(container->callbacks); } @@ -344,6 +344,10 @@ static bool abort_callback(void * user_data) { return true; } + if (!ruby_whisper_abort_callback_container_is_present(container)) { + return false; + } + call_abort_callbacks_args args = { container, NULL, @@ -374,7 +378,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (abort_ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { + if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); }