Implement some hooks of Parakeet::Params

This commit is contained in:
Kitaiti Makoto 2026-05-19 23:03:55 +09:00
parent 10cf358ced
commit 94f327a67b
7 changed files with 182 additions and 24 deletions

View File

@ -36,6 +36,7 @@ typedef struct {
/*
 * User data passed to the Parakeet abort callback
 * (installed as params.abort_callback_user_data during transcription setup).
 */
typedef struct ruby_whisper_parakeet_abort_callback_user_data {
/* Interruption flag; presumably compared against 1 by the abort callback -- TODO confirm. */
volatile rb_atomic_t is_interrupted;
/* Container of Ruby-side abort callbacks; stays NULL unless one is registered. */
ruby_whisper_callback_container *callback_container;
} ruby_whisper_parakeet_abort_callback_user_data;
typedef struct ruby_whisper_log {

View File

@ -10,11 +10,18 @@
ITERATOR(left_context_ms, INT) \
ITERATOR(right_context_ms, INT)
/*
 * X-macro listing the base names of the "normal" callbacks (every callback
 * other than abort). Expands to ITERATOR(<name>, DATA) once per callback.
 */
#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \
ITERATOR(new_segment, DATA) \
ITERATOR(new_token, DATA) \
ITERATOR(progress, DATA) \
ITERATOR(encoder_begin, DATA)
/* Adapter: maps a base name to its parameter name, i.e. ITERATOR(<name>_callback). */
#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback)
/* Expands to ITERATOR(<name>_callback) for each normal callback. */
#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR)
#define ITERATE_CALLBACK_PARAMS(ITERATOR) \
ITERATOR(new_segment_callback) \
ITERATOR(new_token_callback) \
ITERATOR(progress_callback) \
ITERATOR(encoder_begin_callback) \
ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
ITERATOR(abort_callback)
enum {
@ -34,14 +41,133 @@ enum {
/* Converts a C truth value to the Ruby boolean Qtrue/Qfalse. */
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)

/* Symbols defined in other translation units of the extension. */
extern VALUE cParakeetParams;
extern ID id_call;
extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
/*
 * token_data is const-qualified so that this prototype is compatible with the
 * function's definition and accepts the const pointer held by the new-token
 * callback arguments without discarding qualifiers.
 */
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);

/* Per-parameter tables, each with RUBY_WHISPER_PARAKEET_NUM_PARAMS slots. */
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
/* Signature of a parameter writer: takes two VALUEs and returns a VALUE. */
typedef VALUE (*param_writer_t)(VALUE, VALUE);
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
/*
 * Arguments carried through rb_thread_call_with_gvl() to
 * call_parakeet_new_segment_callbacks(). Field order matters: the caller
 * initializes this struct positionally.
 */
typedef struct {
  const ruby_whisper_callback_container *container;
  struct parakeet_state *state;
  int n_new;
} call_parakeet_new_segment_callbacks_args;

/*
 * Runs the Ruby-side new-segment callbacks. Uses rb_funcall(), so it is meant
 * to be invoked through rb_thread_call_with_gvl().
 *
 * v_args: pointer to a call_parakeet_new_segment_callbacks_args.
 *
 * - If container->callback is set, it is called once with
 *   (context, nil, n_new, user_data).
 * - Each proc in container->callbacks is then called once per new segment
 *   with the segment object; the segments are the last n_new of the state,
 *   visited in ascending index order.
 *
 * Always returns NULL.
 */
static void*
call_parakeet_new_segment_callbacks(void *v_args)
{
  call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args;
  const ruby_whisper_callback_container *container = args->container;

  if (!NIL_P(container->callback)) {
    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data);
  }

  if (NIL_P(container->callbacks)) {
    return NULL;
  }
  const long n_callbacks = RARRAY_LEN(container->callbacks);
  if (n_callbacks == 0) {
    return NULL;
  }

  const int n_segments = parakeet_full_n_segments_from_state(args->state);
  for (int i = args->n_new; i > 0; i--) {
    const int i_segment = n_segments - i;
    VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment);
    /* Index with long: RARRAY_LEN() yields long and rb_ary_entry() takes a long offset. */
    for (long j = 0; j < n_callbacks; j++) {
      VALUE cb = rb_ary_entry(container->callbacks, j);
      rb_funcall(cb, id_call, 1, segment);
    }
  }

  return NULL;
}
static void
ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_new_segment_callbacks_args args = {
container,
state,
n_new,
};
rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args);
}
/*
 * Arguments carried through rb_thread_call_with_gvl() to
 * call_parakeet_new_token_callbacks(). Field order matters: the caller
 * initializes this struct positionally.
 */
typedef struct {
  const ruby_whisper_callback_container *container;
  struct parakeet_context *context;
  struct parakeet_state *state;
  const parakeet_token_data *token_data;
} call_parakeet_new_token_callbacks_args;

/*
 * Runs the Ruby-side new-token callbacks. Uses rb_funcall(), so it is meant
 * to be invoked through rb_thread_call_with_gvl().
 *
 * v_args: pointer to a call_parakeet_new_token_callbacks_args.
 *
 * - If container->callback is set, it is called with
 *   (context, nil, token, user_data).
 * - Each proc in container->callbacks is then called with the token.
 *
 * The Ruby token object is built at most once and shared by both paths.
 * Always returns NULL.
 */
static void*
call_parakeet_new_token_callbacks(void *v_args)
{
  call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args;
  const ruby_whisper_callback_container *container = args->container;
  VALUE token = Qnil;

  if (!NIL_P(container->callback)) {
    token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
    rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data);
  }

  if (NIL_P(container->callbacks)) {
    return NULL;
  }
  const long n_callbacks = RARRAY_LEN(container->callbacks);
  if (n_callbacks == 0) {
    return NULL;
  }

  /* Only build the token here if the single-callback path did not already. */
  if (NIL_P(token)) {
    token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
  }
  /* Index with long: RARRAY_LEN() yields long and rb_ary_entry() takes a long offset. */
  for (long i = 0; i < n_callbacks; i++) {
    VALUE cb = rb_ary_entry(container->callbacks, i);
    rb_funcall(cb, id_call, 1, token);
  }

  return NULL;
}
static void
ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_new_token_callbacks_args args = {
container,
context,
state,
token_data,
};
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
}
/*
 * Progress callback installed into the Parakeet params.
 * Currently a no-op: none of the arguments are used and no Ruby callback
 * is dispatched.
 */
static void
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
{
  (void)context;
  (void)state;
  (void)progress;
  (void)user_data;
}
/*
 * Encoder-begin callback installed into the Parakeet params.
 * Currently a stub: ignores its arguments and always returns true
 * (presumably "proceed with encoding" -- TODO confirm against the Parakeet API).
 */
static bool
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
{
  (void)context;
  (void)state;
  (void)user_data;

  return true;
}
static bool
ruby_whisper_parakeet_abort_callback(void *user_data)
{
@ -52,9 +178,25 @@ ruby_whisper_parakeet_abort_callback(void *user_data)
return is_interrupted == 1;
}
/* Maps a callback parameter name to the params member holding its container (<name>_container). */
#define CALLBACK_CONTAINER_NAME(name) name ## _container
/*
 * Wires the Ruby-side callbacks of `rwpp` into the native Parakeet params
 * before a transcription run.
 *
 * rwpp:                     params wrapper whose native `params` struct is filled in.
 * context:                  pointer to the Ruby context VALUE; the pointer itself is
 *                           stored in each present container and dereferenced when
 *                           callbacks fire, so it must stay valid for that long.
 * abort_callback_user_data: caller-owned struct installed as the abort callback's
 *                           user data.
 */
void
ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data)
ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data)
{
/* Helper macros naming the native params members for a callback; they are not #undef'd in this function. */
#define PARAM_NAME(name) name
#define USER_DATA_NAME(name) name##_user_data
/*
 * For a callback whose container holds at least one Ruby callback: record the
 * context pointer in the container, install the matching C trampoline
 * (ruby_whisper_parakeet_<name>) and pass the container as its user data.
 * Callbacks without Ruby handlers are left untouched.
 */
#define REGISTER_CALLBACK(name) \
if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \
rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \
rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \
rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \
}
/* Expands REGISTER_CALLBACK for new_segment, new_token, progress and encoder_begin. */
ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK)
/* Attach the Ruby abort container only when it actually holds a callback. */
if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) {
abort_callback_user_data->callback_container = rwpp->abort_callback_container;
}
/* Unlike the callbacks above, the abort callback is installed unconditionally. */
rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback;
rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data;
}
@ -119,8 +261,6 @@ const rb_data_type_t ruby_whisper_parakeet_params_type = {
return val; \
}
#define CALLBACK_CONTAINER_NAME(name) name ## _container
#define DEF_CALLBACK_PARAM_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_params_get_##name(VALUE self) \
@ -155,24 +295,30 @@ const rb_data_type_t ruby_whisper_parakeet_params_type = {
return val; \
}
/*
 * Generates ruby_whisper_parakeet_params_on_<name>(self): captures the
 * caller's block as a Proc via rb_block_proc(), appends it to the `callbacks`
 * array of the corresponding callback container (creating the array on first
 * use), and returns Qnil.
 */
#define DEF_HOOK(name) \
#define DEF_HOOK(name, data) \
static VALUE \
ruby_whisper_parakeet_params_on_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
const VALUE blk = rb_block_proc(); \
if (!rwpp->name##_container->callbacks) { \
rwpp->name##_container->callbacks = rb_ary_new(); \
if (NIL_P(rwpp->name##_callback_container->callbacks)) { \
rwpp->name##_callback_container->callbacks = rb_ary_new(); \
} \
rb_ary_push(rwpp->name##_container->callbacks, blk); \
rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \
return Qnil; \
}
ITERATE_PARAMS(DEF_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_HOOK)
ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
/*
 * Implementation of the `abort_on` method on the Parakeet params class.
 * Currently a placeholder: ignores the receiver and returns nil without
 * registering anything.
 */
static VALUE
ruby_whisper_parakeet_params_abort_on(VALUE self)
{
  (void)self;

  return Qnil;
}
static VALUE
ruby_whisper_parakeet_params_s_allocate(VALUE klass)
@ -240,8 +386,10 @@ init_ruby_whisper_parakeet_params(VALUE *mParakeet)
ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR)
#define REGISTER_HOOK(name) \
#define REGISTER_HOOK(name, data) \
rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0);
ITERATE_CALLBACK_PARAMS(REGISTER_HOOK)
ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _)
rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0);
}

View File

@ -33,7 +33,7 @@ extern VALUE sym_start_time;
extern VALUE sym_end_time;
extern VALUE sym_text;
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern VALUE ruby_whisper_parakeet_token_s_init(struct parakeet_context *context, int i_segment, int i_token);
extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token);
static void
rb_whisper_parakeet_segment_mark(void *p)
@ -96,7 +96,7 @@ ruby_whisper_parakeet_segment_each_token(VALUE self)
const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index);
for (int i = 0; i < n_tokens; i++) {
rb_yield(ruby_whisper_parakeet_token_s_init(rwpc->context, rwps->index, i));
rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i));
}
return self;

View File

@ -108,19 +108,27 @@ ruby_whisper_parakeet_token_s_allocate(VALUE klass)
}
/*
 * Builds a new cParakeetToken Ruby object: allocates the wrapper, copies the
 * token data into it by value, and stores the token's text as a UTF-8 Ruby
 * string. Returns the new token object.
 */
VALUE
ruby_whisper_parakeet_token_s_init(struct parakeet_context *context, int i_segment, int i_token)
ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
{
const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
ruby_whisper_parakeet_token *rwpt;
TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
*rwpt->token_data = parakeet_full_get_token_data(context, i_segment, i_token);
rwpt->text = rb_utf8_str_new_cstr(parakeet_full_get_token_text(context, i_segment, i_token));
/* Copy by value so the Ruby object does not depend on the lifetime of the caller's token_data. */
*rwpt->token_data = *token_data;
rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id));
return token;
}
/*
 * Builds a Ruby token object for the token at (i_segment, i_token) of
 * `context`: fetches the token data by index and delegates to
 * ruby_whisper_parakeet_token_s_from_token_data().
 */
VALUE
ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
{
  parakeet_token_data fetched = parakeet_full_get_token_data(context, i_segment, i_token);
  VALUE token = ruby_whisper_parakeet_token_s_from_token_data(context, &fetched);

  return token;
}
ITERATE_MEMBERS(DEF_MEMBER_ATTR)
// Define #text using parakeet_token_to_str or parakeet_token_to_text
ITERATE_ATTRS(DEF_ATTR)
static VALUE

View File

@ -10,7 +10,7 @@ extern "C" {
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern const rb_data_type_t ruby_whisper_parakeet_params_type;
extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data);
extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_parakeet_abort_callback_user_data *abort_callback_user_data);
extern ID id_to_s;
extern ID id_to_path;
@ -70,8 +70,9 @@ ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params)
ruby_whisper_parakeet_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
ruby_whisper_parakeet_prepare_transcription(rwpp, &abort_callback_user_data);
ruby_whisper_parakeet_prepare_transcription(rwpp, &self, &abort_callback_user_data);
struct transcribe_without_gvl_args args = {
rwpc->context,

View File

@ -119,7 +119,7 @@ rb_whisper_abort_callback_container_allocate() {
return container;
}
static bool
/*
 * Reports whether the container has any Ruby callback attached: either the
 * single `callback` or the `callbacks` array is non-nil. An empty (but
 * non-nil) `callbacks` array therefore still counts as present.
 */
bool
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
  const bool has_single_callback = !NIL_P(container->callback);
  const bool has_callback_list = !NIL_P(container->callbacks);

  return has_single_callback || has_callback_list;
}

View File

@ -15,7 +15,7 @@ extern const rb_data_type_t ruby_whisper_type;
extern VALUE cSegment;
extern VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int index);
extern VALUE ruby_whisper_token_s_from_index(struct whisper_context *context, int i_segment, int index);
static void
rb_whisper_segment_mark(void *p)
@ -190,7 +190,7 @@ ruby_whisper_segment_each_token(VALUE self)
const int n_tokens = whisper_full_n_tokens(rw->context, rws->index);
for (int i = 0; i < n_tokens; ++i) {
rb_yield(ruby_whisper_token_s_init(rw->context, rws->index, i));
rb_yield(ruby_whisper_token_s_from_index(rw->context, rws->index, i));
}
return self;