diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index b669e8197..70b637df6 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -11,6 +11,7 @@ VALUE cVADParams; VALUE cVADSegments; VALUE cVADSegment; VALUE cParakeetContext; +VALUE cParakeetContextParams; VALUE cParakeetParams; VALUE cParakeetSegment; VALUE cParakeetModel; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 9fe3743e4..faed71190 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -126,6 +126,10 @@ typedef struct { ruby_whisper_callback_container *abort_callback_container; } ruby_whisper_parakeet_params; +typedef struct { + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_params; + typedef struct { struct parakeet_context *context; } ruby_whisper_parakeet_context; @@ -180,6 +184,10 @@ typedef struct { } \ } while (0) +#define GetParakeetContextParams(obj, rwpcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \ +} while (0) + #define GetParakeetContext(obj, rwpc) do { \ TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ if ((rwpc)->context == NULL) { \ diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c index cd063fa9f..6d31cee62 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -17,7 +17,8 @@ extern ID id_join; extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); -extern void init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext); extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); @@ -93,7 +94,8 @@ init_ruby_whisper_parakeet(VALUE *mWhisper) init_ruby_whisper_parakeet_params(&mParakeet); init_ruby_whisper_parakeet_token(&mParakeet); init_ruby_whisper_parakeet_segment(&mParakeet); - init_ruby_whisper_parakeet_context(&mParakeet); + cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_context_params(&cParakeetContext); init_ruby_whisper_parakeet_model(&mParakeet); rb_include_module(cParakeetContext, mOutputContext); diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c index a7f0d7a75..e8fd1934a 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_context.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -263,7 +263,7 @@ ruby_whisper_parakeet_context_get_model(VALUE self) return ruby_whisper_parakeet_model_s_new(self); } -void +VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet) { cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject); @@ -287,4 +287,6 @@ init_ruby_whisper_parakeet_context(VALUE *mParakeet) rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2); ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR) + + return cParakeetContext; } diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c new file mode 100644 index 000000000..38bd6d57c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c @@ -0,0 +1,117 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(use_gpu, BOOL) \ + ITERATOR(gpu_device, INT) + +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_INT(v) (NUM2INT(v)) +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type + +#define DEF_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + return READER(type)(rwpcp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + rwpcp->params.name = WRITER(type)(val); \ + return val; \ + } + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS +}; + +extern VALUE cParakeetContextParams; + +typedef VALUE (*param_writer_t)(VALUE, VALUE); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; + +static size_t +ruby_whisper_parakeet_context_params_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_context_params); +} + +const rb_data_type_t ruby_whisper_parakeet_context_params_type = { + "ruby_whisper_parakeet_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context_params *rwpcp; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); +} + +static VALUE +ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_context_params *rwpcp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); + rwpcp->params = parakeet_context_default_params(); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values); + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext) +{ + cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject); + + rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate); + + rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1); + + int i = 0; +#define REGISTER_ATTR(name, type) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \ + rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \ + rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \ + i++; + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index e14d6d98b..0ce85961d 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -701,6 +701,14 @@ module Whisper # def each_segment: { (Segment) -> void } -> void | () -> Enumerator[Segment] + + class Params + def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance + def use_gpu: () -> boolish + def use_gpu=: (boolish) -> boolish + def gpu_device: () -> Integer + def gpu_device=: (Integer) -> Integer + end end class Params diff --git a/bindings/ruby/test/test_parakeet_context_params.rb b/bindings/ruby/test/test_parakeet_context_params.rb new file mode 100644 index 000000000..fcd0f2410 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context_params.rb @@ -0,0 +1,24 @@ +require_relative "helper" + +class TestParakeetContextParams < TestBase + def setup + @params = Parakeet::Context::Params.new + end + + def test_new + assert_instance_of Parakeet::Context::Params, @params + end + + def test_attributes + assert_true @params.use_gpu + assert_instance_of Integer, @params.gpu_device + end + + def test_attribute_writer + @params.use_gpu = false + assert_false @params.use_gpu + + @params.gpu_device = 2 + assert_equal 2, @params.gpu_device + end +end