diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index aacdbf204..d312bfb31 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "whisper.h" #include "parakeet.h" @@ -29,6 +30,10 @@ typedef struct { bool is_interrupted; } ruby_whisper_abort_callback_container; +typedef struct ruby_whisper_parakeet_abort_callback_user_data { + volatile rb_atomic_t is_interrupted; +} ruby_whisper_parakeet_abort_callback_user_data; + typedef struct { struct whisper_context *context; } ruby_whisper; diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c index c8565a9d5..58288b907 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_params.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -44,6 +44,23 @@ static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; typedef VALUE (*param_writer_t)(VALUE, VALUE); static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; +static bool +ruby_whisper_parakeet_abort_callback(void *user_data) +{ + ruby_whisper_parakeet_abort_callback_user_data *data = (ruby_whisper_parakeet_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + + return is_interrupted == 1; +} + +void +ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data) +{ + rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback; + rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data; +} + static void ruby_whisper_parakeet_params_mark(void *p) { diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp index c5e9a412b..ea64a9215 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -10,8 +10,13 @@ extern "C" { extern const rb_data_type_t ruby_whisper_parakeet_context_type; extern const rb_data_type_t ruby_whisper_parakeet_params_type; +extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data); + extern ID id_to_s; extern ID id_to_path; +extern ID id_new; + +extern VALUE eError; static struct transcribe_without_gvl_args { struct parakeet_context *context; @@ -21,6 +26,18 @@ static struct transcribe_without_gvl_args { int result; } transcribe_without_gvl_args; +typedef struct { + ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data; +} ruby_whisper_parakeet_transcribe_ubf_args; + +static void +ruby_whisper_parakeet_transcribe_ubf(void *rb_args) +{ + ruby_whisper_parakeet_transcribe_ubf_args *args = (ruby_whisper_parakeet_transcribe_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + static void* transcribe_without_gvl(void *rb_args) { @@ -51,6 +68,11 @@ ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) GetParakeetContext(self, rwpc); GetParakeetParams(params, rwpp); + ruby_whisper_parakeet_abort_callback_user_data abort_callback_user_data = { + 0, + }; + ruby_whisper_parakeet_prepare_transcription(rwpp, &abort_callback_user_data); + struct transcribe_without_gvl_args args = { rwpc->context, rwpp->params, @@ -59,13 +81,16 @@ ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) 0, }; - rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, NULL, NULL); - if (args.result != 0) { - rb_raise(rb_eRuntimeError, "Failed to process audio"); - return Qnil; - } + ruby_whisper_parakeet_transcribe_ubf_args ubf_args = { + &abort_callback_user_data, + }; - return self; + rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, ruby_whisper_parakeet_transcribe_ubf, (void *)&ubf_args); + if (args.result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, args.result)); + } } #ifdef __cplusplus