Add Parakeet::Context::Params
This commit is contained in:
parent
ada170a786
commit
c6e854b384
|
|
@ -11,6 +11,7 @@ VALUE cVADParams;
|
|||
VALUE cVADSegments;
|
||||
VALUE cVADSegment;
|
||||
VALUE cParakeetContext;
|
||||
VALUE cParakeetContextParams;
|
||||
VALUE cParakeetParams;
|
||||
VALUE cParakeetSegment;
|
||||
VALUE cParakeetModel;
|
||||
|
|
|
|||
|
|
@ -126,6 +126,10 @@ typedef struct {
|
|||
ruby_whisper_callback_container *abort_callback_container;
|
||||
} ruby_whisper_parakeet_params;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context_params params;
|
||||
} ruby_whisper_parakeet_context_params;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context *context;
|
||||
} ruby_whisper_parakeet_context;
|
||||
|
|
@ -180,6 +184,10 @@ typedef struct {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetContextParams(obj, rwpcp) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetContext(obj, rwpc) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \
|
||||
if ((rwpc)->context == NULL) { \
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ extern ID id_join;
|
|||
extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_context(VALUE *mParakeet);
|
||||
extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext);
|
||||
extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet);
|
||||
|
||||
extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue);
|
||||
|
|
@ -93,7 +94,8 @@ init_ruby_whisper_parakeet(VALUE *mWhisper)
|
|||
init_ruby_whisper_parakeet_params(&mParakeet);
|
||||
init_ruby_whisper_parakeet_token(&mParakeet);
|
||||
init_ruby_whisper_parakeet_segment(&mParakeet);
|
||||
init_ruby_whisper_parakeet_context(&mParakeet);
|
||||
cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet);
|
||||
init_ruby_whisper_parakeet_context_params(&cParakeetContext);
|
||||
init_ruby_whisper_parakeet_model(&mParakeet);
|
||||
|
||||
rb_include_module(cParakeetContext, mOutputContext);
|
||||
|
|
|
|||
|
|
@ -263,7 +263,7 @@ ruby_whisper_parakeet_context_get_model(VALUE self)
|
|||
return ruby_whisper_parakeet_model_s_new(self);
|
||||
}
|
||||
|
||||
void
|
||||
VALUE
|
||||
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);
|
||||
|
|
@ -287,4 +287,6 @@ init_ruby_whisper_parakeet_context(VALUE *mParakeet)
|
|||
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2);
|
||||
|
||||
ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR)
|
||||
|
||||
return cParakeetContext;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(use_gpu, BOOL) \
|
||||
ITERATOR(gpu_device, INT)
|
||||
|
||||
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
|
||||
#define VAL_TO_BOOL(v) (RTEST(v))
|
||||
#define VAL_FROM_INT(v) (INT2NUM(v))
|
||||
#define VAL_TO_INT(v) (NUM2INT(v))
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define WRITER(type) VAL_TO_##type
|
||||
|
||||
#define DEF_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_params_get_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context_params *rwpcp; \
|
||||
GetParakeetContextParams(self, rwpcp); \
|
||||
return READER(type)(rwpcp->params.name); \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context_params *rwpcp; \
|
||||
GetParakeetContextParams(self, rwpcp); \
|
||||
rwpcp->params.name = WRITER(type)(val); \
|
||||
return val; \
|
||||
}
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name,
|
||||
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS
|
||||
};
|
||||
|
||||
extern VALUE cParakeetContextParams;
|
||||
|
||||
typedef VALUE (*param_writer_t)(VALUE, VALUE);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
|
||||
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_context_params_memsize(const void *p)
|
||||
{
|
||||
if (!p) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_context_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_context_params_type = {
|
||||
"ruby_whisper_parakeet_context_params",
|
||||
{0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,},
|
||||
0, 0,
|
||||
0,
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
VALUE kw_hash;
|
||||
VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef};
|
||||
VALUE value;
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
int i;
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
|
||||
rwpcp->params = parakeet_context_default_params();
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values);
|
||||
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) {
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
param_writers[i](self, value);
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext)
|
||||
{
|
||||
cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate);
|
||||
|
||||
rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1);
|
||||
|
||||
int i = 0;
|
||||
#define REGISTER_ATTR(name, type) \
|
||||
param_names[i] = rb_intern(#name); \
|
||||
param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \
|
||||
rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \
|
||||
rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \
|
||||
i++;
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
}
|
||||
|
|
@ -701,6 +701,14 @@ module Whisper
|
|||
#
|
||||
def each_segment: { (Segment) -> void } -> void
|
||||
| () -> Enumerator[Segment]
|
||||
|
||||
class Params
|
||||
def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance
|
||||
def use_gpu: () -> boolish
|
||||
def use_gpu=: (boolish) -> boolish
|
||||
def gpu_device: () -> Integer
|
||||
def gpu_device=: (Integer) -> Integer
|
||||
end
|
||||
end
|
||||
|
||||
class Params
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetContextParams < TestBase
|
||||
def setup
|
||||
@params = Parakeet::Context::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Context::Params, @params
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
assert_true @params.use_gpu
|
||||
assert_instance_of Integer, @params.gpu_device
|
||||
end
|
||||
|
||||
def test_attribute_writer
|
||||
@params.use_gpu = false
|
||||
assert_false @params.use_gpu
|
||||
|
||||
@params.gpu_device = 2
|
||||
assert_equal 2, @params.gpu_device
|
||||
end
|
||||
end
|
||||
Loading…
Reference in New Issue