Implement hooks for Parakeet
This commit is contained in:
parent
de505d23ad
commit
b3b9af63b2
|
|
@ -47,7 +47,7 @@ extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container
|
|||
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
|
||||
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
|
||||
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, parakeet_token_data *token_data);
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
|
||||
typedef VALUE (*param_writer_t)(VALUE, VALUE);
|
||||
|
|
@ -157,15 +157,146 @@ ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struc
|
|||
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
int progress;
|
||||
} call_parakeet_progress_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_progress_callback(void *v_args)
|
||||
{
|
||||
call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
rb_funcall(cb, id_call, 1, INT2NUM(args->progress));
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_progress_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
progress,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
bool is_continued;
|
||||
} call_parakeet_encoder_begin_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_encoder_begin_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
|
||||
if (result == Qfalse) {
|
||||
args->is_continued = false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
result = rb_funcall(cb, id_call, 0);
|
||||
if (result == Qfalse) {
|
||||
args->is_continued = false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
|
||||
{
|
||||
return true;
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
call_parakeet_encoder_begin_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
true,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args);
|
||||
|
||||
return args.is_continued;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
bool is_interrupted;
|
||||
} call_parakeet_abort_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_abort_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
VALUE cb;
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
cb = rb_ary_entry(container->callbacks, i);
|
||||
result = rb_funcall(cb, id_call, 0);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static bool
|
||||
|
|
@ -174,8 +305,21 @@ ruby_whisper_parakeet_abort_callback(void *user_data)
|
|||
ruby_whisper_parakeet_abort_callback_user_data *data = (ruby_whisper_parakeet_abort_callback_user_data *)user_data;
|
||||
|
||||
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
|
||||
if (is_interrupted) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return is_interrupted == 1;
|
||||
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
call_parakeet_abort_callbacks_args args = {
|
||||
data->callback_container,
|
||||
false,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args);
|
||||
|
||||
return args.is_interrupted;
|
||||
}
|
||||
|
||||
#define CALLBACK_CONTAINER_NAME(name) name ## _container
|
||||
|
|
@ -317,6 +461,14 @@ ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
|
|||
static VALUE
|
||||
ruby_whisper_parakeet_params_abort_on(VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
GetParakeetParams(self, rwpp);
|
||||
const VALUE blk = rb_block_proc();
|
||||
if (NIL_P(rwpp->abort_callback_container->callbacks)) {
|
||||
rwpp->abort_callback_container->callbacks = rb_ary_new();
|
||||
}
|
||||
rb_ary_push(rwpp->abort_callback_container->callbacks, blk);
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ extern const rb_data_type_t ruby_whisper_type;
|
|||
|
||||
extern VALUE cSegment;
|
||||
|
||||
extern VALUE ruby_whisper_token_s_from_index(struct whisper_context *context, int i_segment, int index);
|
||||
extern VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int index);
|
||||
|
||||
static void
|
||||
rb_whisper_segment_mark(void *p)
|
||||
|
|
@ -190,7 +190,7 @@ ruby_whisper_segment_each_token(VALUE self)
|
|||
|
||||
const int n_tokens = whisper_full_n_tokens(rw->context, rws->index);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
rb_yield(ruby_whisper_token_s_from_index(rw->context, rws->index, i));
|
||||
rb_yield(ruby_whisper_token_s_init(rw->context, rws->index, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
|
|
|
|||
Loading…
Reference in New Issue