Implement hooks for Parakeet

This commit is contained in:
Kitaiti Makoto 2026-05-20 16:47:21 +09:00
parent de505d23ad
commit b3b9af63b2
2 changed files with 157 additions and 5 deletions

View File

@ -47,7 +47,7 @@ extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container
/* Declarations of helpers whose definitions are not in this part of the file. */
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
/*
 * Exactly one prototype for the token constructor. Declaring it both with and
 * without the const qualifier on token_data gives two incompatible function
 * types for the same name, so only the const-qualified form is kept; callers
 * holding a non-const pointer can still pass it.
 */
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);

/* Table of parameter-name IDs, one slot per Parakeet parameter. */
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
/* Function-pointer type taking two VALUEs and returning a VALUE. */
typedef VALUE (*param_writer_t)(VALUE, VALUE);
@ -157,15 +157,146 @@ ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struc
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
}
/*
 * Argument bundle handed to call_parakeet_progress_callback() through the
 * single void* that rb_thread_call_with_gvl() forwards.
 *
 * The native progress hook fills this with a positional initializer, so the
 * field order below must not be changed.
 */
typedef struct {
/* Container holding the Ruby callables to notify. */
const ruby_whisper_callback_container *container;
/* State pointer received by the native hook; the dispatcher does not read it. */
struct parakeet_state *state;
/* Progress value; converted with INT2NUM() before being passed to Ruby. */
int progress;
} call_parakeet_progress_callbacks_args;
/*
 * Dispatches a progress notification to the Ruby callables stored in a
 * callback container. Intended to be run through rb_thread_call_with_gvl().
 *
 * v_args points to a call_parakeet_progress_callbacks_args.
 *
 * - container->callback, when not nil, is called with four arguments:
 *   (*container->context, nil, progress, container->user_data).
 * - every entry of container->callbacks, when that array is not nil, is
 *   called with the progress value only.
 *
 * Always returns NULL.
 */
static void*
call_parakeet_progress_callback(void *v_args)
{
  call_parakeet_progress_callbacks_args *const params = v_args;
  const ruby_whisper_callback_container *const hooks = params->container;
  const int progress = params->progress;

  if (!NIL_P(hooks->callback)) {
    rb_funcall(hooks->callback, id_call, 4, *hooks->context, Qnil, INT2NUM(progress), hooks->user_data);
  }

  if (!NIL_P(hooks->callbacks)) {
    /* Length is sampled once, before any hook runs. */
    const long count = RARRAY_LEN(hooks->callbacks);
    for (long idx = 0; idx < count; ++idx) {
      VALUE hook = rb_ary_entry(hooks->callbacks, idx);
      rb_funcall(hook, id_call, 1, INT2NUM(progress));
    }
  }

  return NULL;
}
static void
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_progress_callbacks_args args = {
container,
state,
progress,
};
rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args);
}
/*
 * Argument bundle handed to call_parakeet_encoder_begin_callbacks() through
 * the single void* that rb_thread_call_with_gvl() forwards.
 *
 * The native encoder-begin hook fills this with a positional initializer, so
 * the field order below must not be changed.
 */
typedef struct {
/* Container holding the Ruby callables to consult. */
const ruby_whisper_callback_container *container;
/* State pointer received by the native hook; the dispatcher does not read it. */
struct parakeet_state *state;
/* In/out flag: initialised to true by the caller, cleared by the dispatcher
 * as soon as a callable returns exactly false. */
bool is_continued;
} call_parakeet_encoder_begin_callbacks_args;
/*
 * Consults the Ruby encoder-begin callables stored in a callback container.
 * Intended to be run through rb_thread_call_with_gvl().
 *
 * v_args points to a call_parakeet_encoder_begin_callbacks_args.
 *
 * - container->callback, when not nil, is called with
 *   (*container->context, nil, container->user_data).
 * - every entry of container->callbacks, when that array is not nil, is
 *   called with no arguments.
 *
 * The first callable that returns exactly false (nil does not count) clears
 * args->is_continued and stops the dispatch; otherwise the flag is left
 * untouched. Always returns NULL.
 */
static void*
call_parakeet_encoder_begin_callbacks(void *v_args)
{
  call_parakeet_encoder_begin_callbacks_args *const params = v_args;
  const ruby_whisper_callback_container *const hooks = params->container;

  if (!NIL_P(hooks->callback)) {
    const VALUE verdict = rb_funcall(hooks->callback, id_call, 3, *hooks->context, Qnil, hooks->user_data);
    if (verdict == Qfalse) {
      params->is_continued = false;
      return NULL;
    }
  }

  if (!NIL_P(hooks->callbacks)) {
    /* Length is sampled once, before any hook runs. */
    const long count = RARRAY_LEN(hooks->callbacks);
    for (long idx = 0; idx < count; ++idx) {
      VALUE hook = rb_ary_entry(hooks->callbacks, idx);
      if (rb_funcall(hook, id_call, 0) == Qfalse) {
        params->is_continued = false;
        break;
      }
    }
  }

  return NULL;
}
/*
 * Native encoder-begin hook. user_data is interpreted as a
 * ruby_whisper_callback_container.
 *
 * Returns true when nothing is registered in the container. Otherwise the
 * Ruby callables are consulted via rb_thread_call_with_gvl() and the result
 * is false only if one of them returned exactly false.
 *
 * A leftover unconditional `return true;` at the top of the body used to make
 * everything below unreachable, so registered hooks were never invoked; it
 * has been removed.
 *
 * context is accepted to match the hook signature but is not used here.
 */
static bool
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
{
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
  if (!ruby_whisper_callback_container_is_present(container)) {
    return true;
  }

  /* Positional order matches the struct: container, state, is_continued. */
  call_parakeet_encoder_begin_callbacks_args args = {
    container,
    state,
    true,
  };
  rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args);

  return args.is_continued;
}
/*
 * Argument bundle handed to call_parakeet_abort_callbacks() through the
 * single void* that rb_thread_call_with_gvl() forwards.
 *
 * The native abort hook fills this with a positional initializer, so the
 * field order below must not be changed.
 */
typedef struct {
/* Container holding the Ruby callables to consult. */
const ruby_whisper_callback_container *container;
/* In/out flag: initialised to false by the caller, set by the dispatcher as
 * soon as a callable returns a truthy value. */
bool is_interrupted;
} call_parakeet_abort_callbacks_args;
/*
 * Consults the Ruby abort callables stored in a callback container.
 * Intended to be run through rb_thread_call_with_gvl().
 *
 * v_args points to a call_parakeet_abort_callbacks_args.
 *
 * - container->callback, when not nil, is called with container->user_data.
 * - every entry of container->callbacks, when that array is not nil, is
 *   called with no arguments.
 *
 * The first callable that returns a truthy value (anything but nil/false)
 * sets args->is_interrupted and stops the dispatch; otherwise the flag is
 * left untouched. Always returns NULL.
 */
static void*
call_parakeet_abort_callbacks(void *v_args)
{
  call_parakeet_abort_callbacks_args *const params = v_args;
  const ruby_whisper_callback_container *const hooks = params->container;

  if (!NIL_P(hooks->callback)) {
    const VALUE verdict = rb_funcall(hooks->callback, id_call, 1, hooks->user_data);
    if (RTEST(verdict)) {
      params->is_interrupted = true;
      return NULL;
    }
  }

  if (!NIL_P(hooks->callbacks)) {
    /* Length is sampled once, before any hook runs. */
    const long count = RARRAY_LEN(hooks->callbacks);
    for (long idx = 0; idx < count; ++idx) {
      VALUE hook = rb_ary_entry(hooks->callbacks, idx);
      if (RTEST(rb_funcall(hook, id_call, 0))) {
        params->is_interrupted = true;
        break;
      }
    }
  }

  return NULL;
}
static bool
@ -174,8 +305,21 @@ ruby_whisper_parakeet_abort_callback(void *user_data)
ruby_whisper_parakeet_abort_callback_user_data *data = (ruby_whisper_parakeet_abort_callback_user_data *)user_data;
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
if (is_interrupted) {
return true;
}
return is_interrupted == 1;
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
return false;
}
call_parakeet_abort_callbacks_args args = {
data->callback_container,
false,
};
rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args);
return args.is_interrupted;
}
#define CALLBACK_CONTAINER_NAME(name) name ## _container
@ -317,6 +461,14 @@ ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
/*
 * Registers the block given to the method as an additional abort hook.
 *
 * The block is captured with rb_block_proc() and appended to the abort
 * callback container's callbacks array, which is created on first use.
 * Returns nil.
 */
static VALUE
ruby_whisper_parakeet_params_abort_on(VALUE self)
{
  ruby_whisper_parakeet_params *params;
  GetParakeetParams(self, params);

  const VALUE hook = rb_block_proc();

  /* Work on the array slot directly so creation and append share one path. */
  VALUE *const slot = &params->abort_callback_container->callbacks;
  if (NIL_P(*slot)) {
    *slot = rb_ary_new();
  }
  rb_ary_push(*slot, hook);

  return Qnil;
}

View File

@ -15,7 +15,7 @@ extern const rb_data_type_t ruby_whisper_type;
extern VALUE cSegment;
extern VALUE ruby_whisper_token_s_from_index(struct whisper_context *context, int i_segment, int index);
extern VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int index);
static void
rb_whisper_segment_mark(void *p)
@ -190,7 +190,7 @@ ruby_whisper_segment_each_token(VALUE self)
const int n_tokens = whisper_full_n_tokens(rw->context, rws->index);
for (int i = 0; i < n_tokens; ++i) {
rb_yield(ruby_whisper_token_s_from_index(rw->context, rws->index, i));
rb_yield(ruby_whisper_token_s_init(rw->context, rws->index, i));
}
return self;