Refactor to make abort callback use the same way to parakeet's way
This commit is contained in:
parent
2926922f64
commit
4f784eebe3
|
|
@ -24,14 +24,6 @@ typedef struct {
|
|||
VALUE callbacks;
|
||||
} ruby_whisper_callback_container;
|
||||
|
||||
typedef struct {
|
||||
VALUE *context;
|
||||
VALUE user_data;
|
||||
VALUE callback;
|
||||
VALUE callbacks;
|
||||
bool is_interrupted;
|
||||
} ruby_whisper_abort_callback_container;
|
||||
|
||||
typedef struct ruby_whisper_abort_callback_user_data {
|
||||
volatile rb_atomic_t is_interrupted;
|
||||
ruby_whisper_callback_container *callback_container;
|
||||
|
|
@ -69,7 +61,7 @@ typedef struct {
|
|||
ruby_whisper_callback_container *new_segment_callback_container;
|
||||
ruby_whisper_callback_container *progress_callback_container;
|
||||
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
VALUE vad_params;
|
||||
} ruby_whisper_params;
|
||||
|
||||
|
|
@ -111,6 +103,13 @@ typedef struct parsed_samples_t {
|
|||
bool memview_exported;
|
||||
} parsed_samples_t;
|
||||
|
||||
typedef struct full_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
} full_args;
|
||||
|
||||
typedef struct {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
|
|||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_s_new(VALUE context);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
|
||||
|
||||
ID transcribe_option_names[1];
|
||||
|
||||
|
|
@ -38,13 +38,6 @@ typedef struct fill_samples_args {
|
|||
int n_samples;
|
||||
} fill_samples_args;
|
||||
|
||||
typedef struct full_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
} full_args;
|
||||
|
||||
typedef struct full_parallel_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
|
|
@ -71,7 +64,7 @@ typedef struct full_parallel_without_gvl_args {
|
|||
} full_parallel_without_gvl_args;
|
||||
|
||||
typedef struct full_ubf_args {
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
|
||||
} full_ubf_args;
|
||||
|
||||
static void
|
||||
|
|
@ -480,10 +473,10 @@ full_ubf(void *rb_args)
|
|||
{
|
||||
full_ubf_args *args = (full_ubf_args *)rb_args;
|
||||
|
||||
args->abort_callback_container->is_interrupted = true;
|
||||
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
VALUE
|
||||
full_body(VALUE rb_args)
|
||||
{
|
||||
full_args *args = (full_args *)rb_args;
|
||||
|
|
@ -493,7 +486,11 @@ full_body(VALUE rb_args)
|
|||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context, 1);
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
prepare_transcription(rwp, args->context, 1, &abort_callback_user_data);
|
||||
|
||||
struct full_without_gvl_args full_without_gvl_args = {
|
||||
rw->context,
|
||||
|
|
@ -503,7 +500,7 @@ full_body(VALUE rb_args)
|
|||
0,
|
||||
};
|
||||
full_ubf_args full_ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args);
|
||||
return INT2NUM(full_without_gvl_args.result);
|
||||
|
|
@ -562,7 +559,11 @@ full_parallel_body(VALUE rb_args)
|
|||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context, args->n_processors);
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data);
|
||||
|
||||
struct full_parallel_without_gvl_args full_parallel_without_gvl_args = {
|
||||
rw->context,
|
||||
|
|
@ -573,7 +574,7 @@ full_parallel_body(VALUE rb_args)
|
|||
0,
|
||||
};
|
||||
full_ubf_args full_ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args);
|
||||
return INT2NUM(full_parallel_without_gvl_args.result);
|
||||
|
|
|
|||
|
|
@ -97,38 +97,11 @@ ruby_whisper_callback_container_allocate() {
|
|||
return container;
|
||||
}
|
||||
|
||||
static void
|
||||
rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc)
|
||||
{
|
||||
if (rwc == NULL) return;
|
||||
|
||||
rb_gc_mark(rwc->user_data);
|
||||
rb_gc_mark(rwc->callback);
|
||||
rb_gc_mark(rwc->callbacks);
|
||||
}
|
||||
|
||||
static ruby_whisper_abort_callback_container*
|
||||
rb_whisper_abort_callback_container_allocate() {
|
||||
ruby_whisper_abort_callback_container *container;
|
||||
container = ALLOC(ruby_whisper_abort_callback_container);
|
||||
container->context = NULL;
|
||||
container->user_data = Qnil;
|
||||
container->callback = Qnil;
|
||||
container->callbacks = Qnil;
|
||||
container->is_interrupted = false;
|
||||
return container;
|
||||
}
|
||||
|
||||
bool
|
||||
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
|
||||
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) {
|
||||
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct whisper_state *state;
|
||||
|
|
@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s
|
|||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_abort_callback_container *container;
|
||||
struct whisper_state *state;
|
||||
const ruby_whisper_callback_container *container;
|
||||
bool is_interrupted;
|
||||
} call_abort_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_abort_callbacks(void *v_args) {
|
||||
call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args;
|
||||
const ruby_whisper_abort_callback_container *container = args->container;
|
||||
|
||||
if (container->is_interrupted) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (!NIL_P(result) && Qfalse != result) {
|
||||
result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) {
|
|||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long callbacks_len = RARRAY_LEN(container->callbacks);
|
||||
if (0 == callbacks_len) {
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (0 == n_callbacks) {
|
||||
return NULL;
|
||||
}
|
||||
for (int j = 0; j < callbacks_len; j++) {
|
||||
for (int j = 0; j < n_callbacks; j++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
|
||||
if (!NIL_P(result) && Qfalse != result) {
|
||||
VALUE result = rb_funcall(cb, id_call, 0);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) {
|
|||
}
|
||||
|
||||
static bool abort_callback(void * user_data) {
|
||||
const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data;
|
||||
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
|
||||
|
||||
if (container->is_interrupted) {
|
||||
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
|
||||
if (is_interrupted) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!ruby_whisper_abort_callback_container_is_present(container)) {
|
||||
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
call_abort_callbacks_args args = {
|
||||
container,
|
||||
NULL,
|
||||
data->callback_container,
|
||||
false
|
||||
};
|
||||
rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args);
|
||||
|
|
@ -364,7 +332,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
|
|||
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) {
|
||||
if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
|
|
@ -374,7 +342,7 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
|
|||
}
|
||||
}
|
||||
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) {
|
||||
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
|
||||
rwp->new_segment_callback_container->context = context;
|
||||
rwp->params.new_segment_callback = new_segment_callback;
|
||||
|
|
@ -393,10 +361,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
|||
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
|
||||
}
|
||||
|
||||
abort_callback_user_data->callback_container = rwp->abort_callback_container;
|
||||
rwp->abort_callback_container->context = context;
|
||||
rwp->params.abort_callback = abort_callback;
|
||||
rwp->abort_callback_container->is_interrupted = false;
|
||||
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
|
||||
rwp->params.abort_callback_user_data = (void *)abort_callback_user_data;
|
||||
}
|
||||
|
||||
static void set_vad_params(ruby_whisper_params *rwp)
|
||||
|
|
@ -410,10 +378,10 @@ static void set_vad_params(ruby_whisper_params *rwp)
|
|||
TODO: Set abort callback to trap SIGINT and SIGTERM
|
||||
*/
|
||||
void
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
|
||||
{
|
||||
check_thread_safety(rwp, n_processors);
|
||||
register_callbacks(rwp, context);
|
||||
register_callbacks(rwp, context, abort_callback_user_data);
|
||||
set_vad_params(rwp);
|
||||
}
|
||||
|
||||
|
|
@ -424,7 +392,7 @@ rb_whisper_params_mark(void *p)
|
|||
ruby_whisper_callback_container_mark(rwp->new_segment_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->progress_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container);
|
||||
rb_whisper_abort_callback_container_mark(rwp->abort_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->abort_callback_container);
|
||||
rb_gc_mark(rwp->vad_params);
|
||||
}
|
||||
|
||||
|
|
@ -495,7 +463,7 @@ ruby_whisper_params_allocate(VALUE klass)
|
|||
rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate();
|
||||
rwp->abort_callback_container = ruby_whisper_callback_container_allocate();
|
||||
return obj;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ extern ID id_to_path;
|
|||
extern ID transcribe_option_names[1];
|
||||
|
||||
extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
|
||||
extern VALUE full_body(VALUE rb_args);
|
||||
|
||||
typedef struct{
|
||||
struct whisper_context *context;
|
||||
|
|
@ -35,18 +36,6 @@ transcribe_without_gvl(void *rb_args)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
} transcribe_ubf_args;
|
||||
|
||||
static void
|
||||
transcribe_ubf(void *rb_args)
|
||||
{
|
||||
transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args;
|
||||
|
||||
args->abort_callback_container->is_interrupted = true;
|
||||
}
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
* can emit to a block results
|
||||
|
|
@ -91,32 +80,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
// Commented out because it is work in progress
|
||||
// {
|
||||
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
// bool is_aborted = *(bool*)user_data;
|
||||
// return !is_aborted;
|
||||
// };
|
||||
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
// }
|
||||
|
||||
prepare_transcription(rwp, &self, n_processors);
|
||||
|
||||
transcribe_without_gvl_args args = {
|
||||
rw->context,
|
||||
&rwp->params,
|
||||
full_args args = {
|
||||
&self,
|
||||
¶ms,
|
||||
pcmf32.data(),
|
||||
pcmf32.size(),
|
||||
n_processors,
|
||||
0,
|
||||
(int)pcmf32.size(),
|
||||
};
|
||||
transcribe_ubf_args ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
};
|
||||
rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args);
|
||||
if (args.result != 0) {
|
||||
VALUE rb_result = full_body((VALUE)&args);
|
||||
const int result = NUM2INT(rb_result);
|
||||
if (result != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue