Refactor to make abort callback use the same way to parakeet's way

This commit is contained in:
Kitaiti Makoto 2026-06-12 17:03:27 +09:00
parent 2926922f64
commit 4f784eebe3
4 changed files with 55 additions and 114 deletions

View File

@ -24,14 +24,6 @@ typedef struct {
VALUE callbacks;
} ruby_whisper_callback_container;
typedef struct {
VALUE *context;
VALUE user_data;
VALUE callback;
VALUE callbacks;
bool is_interrupted;
} ruby_whisper_abort_callback_container;
typedef struct ruby_whisper_abort_callback_user_data {
volatile rb_atomic_t is_interrupted;
ruby_whisper_callback_container *callback_container;
@ -69,7 +61,7 @@ typedef struct {
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_abort_callback_container *abort_callback_container;
ruby_whisper_callback_container *abort_callback_container;
VALUE vad_params;
} ruby_whisper_params;
@ -111,6 +103,13 @@ typedef struct parsed_samples_t {
bool memview_exported;
} parsed_samples_t;
typedef struct full_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
} full_args;
typedef struct {
VALUE *context;
VALUE *params;

View File

@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
ID transcribe_option_names[1];
@ -38,13 +38,6 @@ typedef struct fill_samples_args {
int n_samples;
} fill_samples_args;
typedef struct full_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
} full_args;
typedef struct full_parallel_args {
VALUE *context;
VALUE *params;
@ -71,7 +64,7 @@ typedef struct full_parallel_without_gvl_args {
} full_parallel_without_gvl_args;
typedef struct full_ubf_args {
ruby_whisper_abort_callback_container *abort_callback_container;
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
} full_ubf_args;
static void
@ -480,10 +473,10 @@ full_ubf(void *rb_args)
{
full_ubf_args *args = (full_ubf_args *)rb_args;
args->abort_callback_container->is_interrupted = true;
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
}
static VALUE
VALUE
full_body(VALUE rb_args)
{
full_args *args = (full_args *)rb_args;
@ -493,7 +486,11 @@ full_body(VALUE rb_args)
GetContext(*args->context, rw);
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
prepare_transcription(rwp, args->context, 1);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
prepare_transcription(rwp, args->context, 1, &abort_callback_user_data);
struct full_without_gvl_args full_without_gvl_args = {
rw->context,
@ -503,7 +500,7 @@ full_body(VALUE rb_args)
0,
};
full_ubf_args full_ubf_args = {
rwp->abort_callback_container,
&abort_callback_user_data,
};
rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_without_gvl_args.result);
@ -562,7 +559,11 @@ full_parallel_body(VALUE rb_args)
GetContext(*args->context, rw);
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
prepare_transcription(rwp, args->context, args->n_processors);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data);
struct full_parallel_without_gvl_args full_parallel_without_gvl_args = {
rw->context,
@ -573,7 +574,7 @@ full_parallel_body(VALUE rb_args)
0,
};
full_ubf_args full_ubf_args = {
rwp->abort_callback_container,
&abort_callback_user_data,
};
rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_parallel_without_gvl_args.result);

View File

@ -97,38 +97,11 @@ ruby_whisper_callback_container_allocate() {
return container;
}
static void
rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc)
{
if (rwc == NULL) return;
rb_gc_mark(rwc->user_data);
rb_gc_mark(rwc->callback);
rb_gc_mark(rwc->callbacks);
}
static ruby_whisper_abort_callback_container*
rb_whisper_abort_callback_container_allocate() {
ruby_whisper_abort_callback_container *container;
container = ALLOC(ruby_whisper_abort_callback_container);
container->context = NULL;
container->user_data = Qnil;
container->callback = Qnil;
container->callbacks = Qnil;
container->is_interrupted = false;
return container;
}
bool
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
}
static bool
ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) {
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct whisper_state *state;
@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s
}
typedef struct {
const ruby_whisper_abort_callback_container *container;
struct whisper_state *state;
const ruby_whisper_callback_container *container;
bool is_interrupted;
} call_abort_callbacks_args;
static void*
call_abort_callbacks(void *v_args) {
call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args;
const ruby_whisper_abort_callback_container *container = args->container;
if (container->is_interrupted) {
args->is_interrupted = true;
return NULL;
}
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) {
if (NIL_P(container->callbacks)) {
return NULL;
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (0 == n_callbacks) {
return NULL;
}
for (int j = 0; j < callbacks_len; j++) {
for (int j = 0; j < n_callbacks; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
VALUE result = rb_funcall(cb, id_call, 0);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) {
}
static bool abort_callback(void * user_data) {
const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data;
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
if (container->is_interrupted) {
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
if (is_interrupted) {
return true;
}
if (!ruby_whisper_abort_callback_container_is_present(container)) {
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
return false;
}
call_abort_callbacks_args args = {
container,
NULL,
data->callback_container,
false
};
rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args);
@ -364,7 +332,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
}
if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) {
if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) {
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
}
@ -374,7 +342,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
}
}
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) {
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
rwp->new_segment_callback_container->context = context;
rwp->params.new_segment_callback = new_segment_callback;
@ -393,10 +361,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
abort_callback_user_data->callback_container = rwp->abort_callback_container;
rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback;
rwp->abort_callback_container->is_interrupted = false;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
rwp->params.abort_callback_user_data = (void *)abort_callback_user_data;
}
static void set_vad_params(ruby_whisper_params *rwp)
@ -410,10 +378,10 @@ static void set_vad_params(ruby_whisper_params *rwp)
TODO: Set abort callback to trap SIGINT and SIGTERM
*/
void
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
{
check_thread_safety(rwp, n_processors);
register_callbacks(rwp, context);
register_callbacks(rwp, context, abort_callback_user_data);
set_vad_params(rwp);
}
@ -424,7 +392,7 @@ rb_whisper_params_mark(void *p)
ruby_whisper_callback_container_mark(rwp->new_segment_callback_container);
ruby_whisper_callback_container_mark(rwp->progress_callback_container);
ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_abort_callback_container_mark(rwp->abort_callback_container);
ruby_whisper_callback_container_mark(rwp->abort_callback_container);
rb_gc_mark(rwp->vad_params);
}
@ -495,7 +463,7 @@ ruby_whisper_params_allocate(VALUE klass)
rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate();
rwp->progress_callback_container = ruby_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate();
rwp->abort_callback_container = ruby_whisper_callback_container_allocate();
return obj;
}

View File

@ -16,6 +16,7 @@ extern ID id_to_path;
extern ID transcribe_option_names[1];
extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
extern VALUE full_body(VALUE rb_args);
typedef struct{
struct whisper_context *context;
@ -35,18 +36,6 @@ transcribe_without_gvl(void *rb_args)
return NULL;
}
typedef struct {
ruby_whisper_abort_callback_container *abort_callback_container;
} transcribe_ubf_args;
static void
transcribe_ubf(void *rb_args)
{
transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args;
args->abort_callback_container->is_interrupted = true;
}
/*
* transcribe a single file
* can emit to a block results
@ -91,32 +80,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
// Commented out because it is work in progress
// {
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
// bool is_aborted = *(bool*)user_data;
// return !is_aborted;
// };
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
// }
prepare_transcription(rwp, &self, n_processors);
transcribe_without_gvl_args args = {
rw->context,
&rwp->params,
full_args args = {
&self,
&params,
pcmf32.data(),
pcmf32.size(),
n_processors,
0,
(int)pcmf32.size(),
};
transcribe_ubf_args ubf_args = {
rwp->abort_callback_container,
};
rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args);
if (args.result != 0) {
VALUE rb_result = full_body((VALUE)&args);
const int result = NUM2INT(rb_result);
if (result != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}