diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ece02732e..b669e8197 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -13,6 +13,7 @@ VALUE cVADSegment; VALUE cParakeetContext; VALUE cParakeetParams; VALUE cParakeetSegment; +VALUE cParakeetModel; VALUE eError; VALUE cSegment; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 67d27ec23..9fe3743e4 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -140,6 +140,10 @@ typedef struct { VALUE text; } ruby_whisper_parakeet_token; +typedef struct { + VALUE context; +} ruby_whisper_parakeet_model; + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -208,4 +212,11 @@ typedef struct { } \ } while (0) +#define GetParakeetModel(obj, rwpm) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \ + if (NIL_P((rwpm)->context)) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #endif diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c index 9141ba67d..af0ae6585 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -18,6 +18,7 @@ extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); extern void init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue); @@ -84,6 +85,7 @@ init_ruby_whisper_parakeet(VALUE *mWhisper) init_ruby_whisper_parakeet_token(&mParakeet); init_ruby_whisper_parakeet_segment(&mParakeet); init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_model(&mParakeet); rb_include_module(cParakeetContext, mOutputContext); rb_include_module(cParakeetSegment, mOutputSegment); diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c index cdd808d75..a7f0d7a75 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_context.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -33,6 +33,7 @@ extern VALUE release_samples(VALUE rb_parsed_args); extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); extern rb_data_type_t ruby_whisper_parakeet_params_type; extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); static void ruby_whisper_parakeet_context_free(void *p) @@ -256,6 +257,12 @@ ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self) } } +static VALUE +ruby_whisper_parakeet_context_get_model(VALUE self) +{ + return ruby_whisper_parakeet_model_s_new(self); +} + void init_ruby_whisper_parakeet_context(VALUE *mParakeet) { @@ -267,6 +274,7 @@ init_ruby_whisper_parakeet_context(VALUE *mParakeet) rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0); rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2); + rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0); rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0); rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1); diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_model.c b/bindings/ruby/ext/ruby_whisper_parakeet_model.c new file mode 100644 index 000000000..cc43ab917 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_model.c @@ -0,0 +1,75 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(n_vocab) \ + ITERATOR(n_audio_ctx) \ + ITERATOR(n_audio_state) \ + ITERATOR(n_audio_head) \ + ITERATOR(n_audio_layer) \ + ITERATOR(n_mels) \ + ITERATOR(ftype) + +extern rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE cParakeetModel; + +static void +ruby_whisper_parakeet_model_mark(void *p) +{ + ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p; + if (rwpm->context) { + rb_gc_mark(rwpm->context); + } +} + +static const rb_data_type_t ruby_whisper_parakeet_model_type = { + "ruby_whisper_parakeet_model", + {ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_model_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_model *rwpm; + VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = Qnil; + + return model; +} + +VALUE +ruby_whisper_parakeet_model_s_new(VALUE context) +{ + const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel); + ruby_whisper_parakeet_model *rwpm; + TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = context; + return model; +} + +#define DEF_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_model_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_model *rwpm; \ + GetParakeetModel(self, rwpm); \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(rwpm->context, rwpc); \ + return INT2NUM(parakeet_model_##name(rwpc->context)); \ + } + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_model(VALUE *mParakeet) +{ + cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject); + + rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate); + +#define REGISTER_ATTR(name) \ + rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) +}