diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c index e8fd1934a..bc39e9a55 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_context.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -32,6 +32,7 @@ extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); extern VALUE release_samples(VALUE rb_parsed_args); extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); extern rb_data_type_t ruby_whisper_parakeet_params_type; +extern rb_data_type_t ruby_whisper_parakeet_context_params_type; extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); @@ -78,13 +79,14 @@ ruby_whisper_parakeet_context_allocate(VALUE klass) typedef struct { struct parakeet_context **context; char *model_path; + struct parakeet_context_params params; } ruby_whisper_parakeet_context_init_args; static void* ruby_whisper_parakeet_context_init_without_gvl(void *args) { ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args; - *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, parakeet_context_default_params()); + *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params); return NULL; } @@ -93,17 +95,27 @@ ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper_parakeet_context *rwpc; VALUE model_path; + VALUE context_params; + struct parakeet_context_params params; - rb_scan_args(argc, argv, "1", &model_path); + rb_scan_args(argc, argv, "11", &model_path, &context_params); TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); model_path = ruby_whisper_normalize_model_path(model_path); if (!rb_respond_to(model_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context"); } + if (NIL_P(context_params)) { + params = parakeet_context_default_params(); + } else { + ruby_whisper_parakeet_context_params *rwpcp; + GetParakeetContextParams(context_params, rwpcp); + params = rwpcp->params; + } ruby_whisper_parakeet_context_init_args init_args = { &rwpc->context, StringValueCStr(model_path), + params, }; rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL); if (rwpc->context == NULL) {