120 lines
3.5 KiB
C
120 lines
3.5 KiB
C
#include "ruby_whisper.h"
|
|
|
|
extern ID id_to_s;
|
|
extern ID id___method__;
|
|
extern ID id_to_enum;
|
|
|
|
extern VALUE cParakeetContext;
|
|
|
|
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
|
|
extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params);
|
|
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
|
|
|
|
static void
|
|
ruby_whisper_parakeet_context_free(void *p)
|
|
{
|
|
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
|
|
if (rwpc->context) {
|
|
parakeet_free(rwpc->context);
|
|
rwpc->context = NULL;
|
|
}
|
|
}
|
|
|
|
static size_t
|
|
ruby_whisper_parakeet_context_memsize(const void *p)
|
|
{
|
|
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
|
|
if (!rwpc) {
|
|
return 0;
|
|
}
|
|
size_t size = sizeof(*rwpc);
|
|
return size;
|
|
}
|
|
|
|
const rb_data_type_t ruby_whisper_parakeet_context_type = {
|
|
"ruby_whisper_parakeet_context",
|
|
{0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,},
|
|
0, 0,
|
|
0
|
|
};
|
|
|
|
static VALUE
|
|
ruby_whisper_parakeet_context_allocate(VALUE klass)
|
|
{
|
|
ruby_whisper_parakeet_context *rwpc;
|
|
|
|
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
|
|
rwpc->context = NULL;
|
|
|
|
return obj;
|
|
}
|
|
|
|
/*
 * Argument bundle for ruby_whisper_parakeet_context_init_without_gvl(),
 * which receives it through the single void * that
 * rb_thread_call_without_gvl() forwards.
 *
 * Field order matters: the struct is filled with a positional initializer.
 */
typedef struct {
  struct parakeet_context **context; /* out: slot that receives the loaded context */
  char *model_path;                  /* in: model file path as a C string */
} ruby_whisper_parakeet_context_init_args;
|
|
|
|
static void*
|
|
ruby_whisper_parakeet_context_init_without_gvl(void *args)
|
|
{
|
|
ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args;
|
|
*init_args->context = parakeet_init_from_file_with_params(init_args->model_path, parakeet_context_default_params());
|
|
return NULL;
|
|
}
|
|
|
|
static VALUE
|
|
ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self)
|
|
{
|
|
ruby_whisper_parakeet_context *rwpc;
|
|
VALUE model_path;
|
|
|
|
rb_scan_args(argc, argv, "1", &model_path);
|
|
TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
|
|
|
|
model_path = ruby_whisper_normalize_model_path(model_path);
|
|
if (!rb_respond_to(model_path, id_to_s)) {
|
|
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context");
|
|
}
|
|
ruby_whisper_parakeet_context_init_args init_args = {
|
|
&rwpc->context,
|
|
StringValueCStr(model_path),
|
|
};
|
|
rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL);
|
|
if (rwpc->context == NULL) {
|
|
rb_raise(rb_eRuntimeError, "Failed to load model");
|
|
}
|
|
|
|
return Qnil;
|
|
}
|
|
|
|
static VALUE
|
|
ruby_whisper_parakeet_context_each_segment(VALUE self)
|
|
{
|
|
if (!rb_block_given_p()) {
|
|
const VALUE method_name = rb_funcall(self, id___method__, 0);
|
|
return rb_funcall(self, id_to_enum, 1, method_name);
|
|
}
|
|
|
|
ruby_whisper_parakeet_context *rwpc;
|
|
GetParakeetContext(self, rwpc);
|
|
|
|
const int n_segments = parakeet_full_n_segments(rwpc->context);
|
|
for (int i = 0; i < n_segments; ++i) {
|
|
rb_yield(ruby_whisper_parakeet_segment_init(self, i));
|
|
}
|
|
|
|
return self;
|
|
}
|
|
|
|
void
|
|
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
|
|
{
|
|
cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);
|
|
|
|
rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate);
|
|
|
|
rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1);
|
|
rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2);
|
|
rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0);
|
|
}
|