diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index a94797b6e..07e5e0f1e 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -87,6 +87,11 @@ typedef struct parsed_samples_t { typedef struct { struct parakeet_full_params params; + ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *new_token_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; + ruby_whisper_callback_container *abort_callback_container; } ruby_whisper_parakeet_params; #define GetContext(obj, rw) do { \ diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c index 065aee1f8..a12d113e2 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_params.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -1,17 +1,27 @@ #include "ruby_whisper.h" #define ITERATE_PARAMS(ITERATOR) \ - ITERATOR(n_threads, INT) \ - ITERATOR(offset_ms, INT) \ - ITERATOR(duration_ms, INT) \ - ITERATOR(no_context, BOOL) \ - ITERATOR(audio_ctx, INT) \ - ITERATOR(chunk_length_ms, INT) \ - ITERATOR(left_context_ms, INT) \ - ITERATOR(right_context_ms, INT) + ITERATOR(n_threads, n_threads,INT) \ + ITERATOR(offset_ms, offset_ms, INT) \ + ITERATOR(duration_ms, duration_ms, INT) \ + ITERATOR(no_context, no_context, BOOL) \ + ITERATOR(audio_ctx, audio_ctx, INT) \ + ITERATOR(chunk_length_ms, chunk_length_ms, INT) \ + ITERATOR(left_context_ms, left_context_ms, INT) \ + ITERATOR(right_context_ms, right_context_ms, INT) \ + ITERATOR(new_segment_callback, new_segment_callback, CALLBACK) \ + ITERATOR(new_segment_callback_user_data, new_segment_callback, USER_DATA) \ + ITERATOR(new_token_callback, new_token_callback, CALLBACK) \ + ITERATOR(new_token_callback_user_data, new_token_callback, USER_DATA) \ + ITERATOR(progress_callback, progress_callback, CALLBACK) \ + ITERATOR(progress_callback_user_data, progress_callback, USER_DATA) \ + ITERATOR(encoder_begin_callback, encoder_begin_callback, CALLBACK) \ + ITERATOR(encoder_begin_callback_user_data, encoder_begin_callback, USER_DATA) \ + ITERATOR(abort_callback, abort_callback, CALLBACK) \ + ITERATOR(abort_callback_user_data, abort_callback, USER_DATA) enum { -#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX(name, cb, type) RUBY_WHISPER_PARAKEET_PARAM_##name, ITERATE_PARAMS(DEF_IDX) #undef DEF_IDX RUBY_WHISPER_PARAKEET_NUM_PARAMS @@ -22,6 +32,9 @@ enum { #define VAL_TO_BOOL(v) (RTEST(v)) #define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse) +extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc); +extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); + static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; typedef VALUE (*param_writer_t)(VALUE, VALUE); static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; @@ -29,11 +42,34 @@ static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; static void ruby_whisper_parakeet_params_mark(void *p) { + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + ruby_whisper_callback_container_mark(rwpp->new_segment_callback_container); + ruby_whisper_callback_container_mark(rwpp->new_token_callback_container); + ruby_whisper_callback_container_mark(rwpp->progress_callback_container); + ruby_whisper_callback_container_mark(rwpp->encoder_begin_callback_container); + ruby_whisper_callback_container_mark(rwpp->abort_callback_container); } static void ruby_whisper_parakeet_params_free(void *p) { + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + if (rwpp->params.new_segment_callback_user_data) { + xfree(&rwpp->params.new_segment_callback_user_data); + } + if (rwpp->params.new_token_callback_user_data) { + xfree(&rwpp->params.new_token_callback_user_data); + } + if (rwpp->params.progress_callback_user_data) { + xfree(&rwpp->params.progress_callback_user_data); + } + if (rwpp->params.encoder_begin_callback_user_data) { + xfree(&rwpp->params.encoder_begin_callback_user_data); + } + if (rwpp->params.abort_callback_user_data) { + xfree(&rwpp->params.abort_callback_user_data); + } + xfree(rwpp); } static size_t @@ -54,26 +90,80 @@ const rb_data_type_t ruby_whisper_parakeet_params_type = { 0 }; -#define DEF_PARAM_ATTR(name, type) \ +#define DEF_BOOL_PARAM_ATTR(name, cb) \ static VALUE \ ruby_whisper_parakeet_params_get_##name(VALUE self) \ { \ ruby_whisper_parakeet_params *rwpp; \ TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ - return VAL_FROM_##type(rwpp->params.name); \ + return VAL_FROM_BOOL(rwpp->params.name); \ } \ static VALUE \ ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ { \ ruby_whisper_parakeet_params *rwpp; \ TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ - rwpp->params.name = VAL_TO_##type(val); \ + rwpp->params.name = VAL_TO_BOOL(val); \ return val; \ } -ITERATE_PARAMS(DEF_PARAM_ATTR) +#define DEF_INT_PARAM_ATTR(name, cb) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ + return VAL_FROM_INT(rwpp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ + rwpp->params.name = VAL_TO_INT(val); \ + return val; \ + } -#undef DEF_PARAM_ATTR +#define CALLBACK_CONTAINER_NAME(name) name ## _container + +#define DEF_CALLBACK_PARAM_ATTR(name, cb) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(cb)->callback; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(cb)->callback = (val); \ + return val; \ + } + +#define DEF_USER_DATA_PARAM_ATTR(name, cb) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(cb)->user_data; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(cb)->user_data = val; \ + return val; \ + } + +#define DEF_PARAM_ATTR(name, cb, type) DEF_PARAM_ATTR_I(name, cb, type) +#define DEF_PARAM_ATTR_I(name, cb, type) DEF_##type##_PARAM_ATTR(name, cb) + +ITERATE_PARAMS(DEF_PARAM_ATTR) static VALUE ruby_whisper_parakeet_params_s_allocate(VALUE klass) @@ -89,19 +179,28 @@ ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self) { VALUE kw_hash; VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef}; + VALUE id; VALUE value; ruby_whisper_parakeet_params *rwpp; int i; + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + + rwpp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwpp->new_token_callback_container = ruby_whisper_callback_container_allocate(); + rwpp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwpp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); + rwpp->abort_callback_container = ruby_whisper_callback_container_allocate(); + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); if (NIL_P(kw_hash)) { return Qnil; } rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values); - TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) { + id = param_names[i]; value = values[i]; if (value == Qundef) { continue; @@ -121,7 +220,7 @@ init_ruby_whisper_parakeet_params(VALUE *mParakeet) rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1); int i = 0; -#define REGISTER_PARAM_ATTR(name, type) \ +#define REGISTER_PARAM_ATTR(name, cb, type) \ param_names[i] = rb_intern(#name); \ param_writers[i] = ruby_whisper_parakeet_params_set_##name; \ rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \ diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 2aae7c12d..a8632cd87 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -76,8 +76,8 @@ static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; -static void -rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +void +ruby_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) { if (rwc == NULL) return; @@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_gc_mark(rwc->callbacks); } -static ruby_whisper_callback_container* -rb_whisper_callback_container_allocate() { +ruby_whisper_callback_container* +ruby_whisper_callback_container_allocate() { ruby_whisper_callback_container *container; container = ALLOC(ruby_whisper_callback_container); container->context = NULL; @@ -492,9 +492,9 @@ ruby_whisper_params_allocate(VALUE klass) } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); + rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); return obj; }