diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 80a243e4c..8cdb7a810 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -27,6 +27,6 @@ jobs: steps: - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 with: - ruby-version: '3.2' + ruby-version: '3.3' - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: rake test diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 07b81830c..7f6b7d92c 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -396,6 +396,37 @@ whisper .full(Whisper::Params.new, samples) ``` +### Parakeet ### + +whispercpp gem now supports NVIDIA's ASR model Parakeet. + +If you want to use Parakeet instead of Whisper, the API should feel familiar. +In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way. + +```ruby +require "whisper" + +# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem. +Parakeet = Whisper::Parakeet + +parakeet = Parakeet::Context.new("path/to/model") + +params = Parakeet::Params.new( + no_context: true +) + +parakeet + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}" + end +``` + +The main differences are: + +* Namespace is `Whisper::Parakeet`. +* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks. + Custom context params --------------------- diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 7b521b3bd..2327651a0 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -84,6 +84,21 @@ else end end +TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin" +TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin")) +TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d") +directory TEST_PARAKEET_MODEL_DIR +if File.exist? TEST_PARAKEET_MODEL_SRC + file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t| + symlink t.source, t.name + end +else + require "open-uri" + file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t| + File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read + end +end + TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| chdir "test/jfk_reader" do @@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| end CLEAN.include TEST_MEMORY_VIEW -task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO] +task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL] diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 56fceb1c8..7941b1a99 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,19 +1,29 @@ #include "ruby_whisper.h" VALUE mWhisper; +VALUE mLogSettable; VALUE mVAD; +VALUE mParakeet; VALUE cContext; VALUE cParams; VALUE cVADContext; VALUE cVADParams; VALUE cVADSegments; VALUE cVADSegment; +VALUE cParakeetContext; +VALUE cParakeetContextParams; +VALUE cParakeetParams; +VALUE cParakeetSegment; +VALUE cParakeetModel; VALUE eError; VALUE cSegment; VALUE cToken; VALUE cModel; +VALUE mOutputContext; +VALUE mOutputSegment; + ID id_to_s; ID id_call; ID id___method__; @@ -27,9 +37,11 @@ ID id_pre_converted_models; ID id_coreml_compiled_models; ID id_cache; ID id_n_processors; - -static bool is_log_callback_finalized = false; -static bool is_ruby_log_callback_present = false; +ID id_extended; +ID id_start_log_callback_thread; +ID id_log_callback_thread; +ID id_alive_p; +ID id_join; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD); extern void init_ruby_whisper_vad_context(VALUE *mVAD); extern void init_ruby_whisper_vad_segment(VALUE *mVAD); extern void init_ruby_whisper_vad_segments(VALUE *mVAD); +extern void init_ruby_whisper_parakeet(VALUE *mWhisper); extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); +static ruby_whisper_log_queue whisper_log_queue; + +LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set) + /* * call-seq: * lang_max_id -> Integer @@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) { return rb_str_new2(whisper_print_system_info()); } -static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { - is_log_callback_finalized = true; - return Qnil; -} - -typedef struct { - int level; - const char * buffer; -} call_log_callbacks_args; - -static void* -call_log_callbacks(void *v_args) { - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { - return NULL; - } - - call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; - VALUE user_data = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); - - return NULL; -} - -static void -ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { - if (is_log_callback_finalized) { - return; - } - if (!is_ruby_log_callback_present) { - return; - } - - call_log_callbacks_args args = { - level, - buffer, - }; - if (ruby_thread_has_gvl_p()) { - call_log_callbacks((void *)&args); - } else { - rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); - } -} - -/* - * call-seq: - * log_set ->(level, buffer, user_data) { ... }, user_data -> nil - */ -static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { - VALUE old_callback = rb_iv_get(self, "log_callback"); - if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); - } - - rb_iv_set(self, "log_callback", log_callback); - rb_iv_set(self, "user_data", user_data); - - if (!NIL_P(log_callback)) { - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); - } - - if (NIL_P(log_callback)) { - whisper_log_set(NULL, NULL); - is_ruby_log_callback_present = false; - } else { - whisper_log_set(ruby_whisper_log_callback, NULL); - is_ruby_log_callback_present = true; - } - - return Qnil; -} - void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -189,9 +133,19 @@ void Init_whisper() { id_coreml_compiled_models = rb_intern("coreml_compiled_models"); id_cache = rb_intern("cache"); id_n_processors = rb_intern("n_processors"); + id_extended = rb_intern("extended"); + id_start_log_callback_thread = rb_intern("start_log_callback_thread"); + id_log_callback_thread = rb_intern("@log_callback_thread"); + id_alive_p = rb_intern("alive?"); + id_join = rb_intern("join"); mWhisper = rb_define_module("Whisper"); + rb_require("whisper/log_settable"); + mLogSettable = rb_path2class("Whisper::LogSettable"); mVAD = rb_define_module_under(mWhisper, "VAD"); + rb_require("whisper/output"); + mOutputContext = rb_path2class("Whisper::Output::Context"); + mOutputSegment = rb_path2class("Whisper::Output::Segment"); rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version())); rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); @@ -222,8 +176,8 @@ void Init_whisper() { rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0); - rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); - rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); + + LOG_SETTABLE_INIT(whisper_log_queue, mWhisper) cContext = init_ruby_whisper_context(&mWhisper); init_ruby_whisper_context_params(&cContext); @@ -236,8 +190,10 @@ void Init_whisper() { init_ruby_whisper_vad_segment(&mVAD); init_ruby_whisper_vad_segments(&mVAD); init_ruby_whisper_vad_context(&mVAD); + init_ruby_whisper_parakeet(&mWhisper); - rb_require("whisper/context"); - rb_require("whisper/segment"); rb_require("whisper/model/uri"); + + rb_include_module(cContext, mOutputContext); + rb_include_module(cSegment, mOutputSegment); } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index ba4d8b6fb..10e906749 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,8 +5,12 @@ #include #include #include +#include +#include #include #include "whisper.h" +#include "parakeet.h" +#include "ruby_whisper_log_settable.h" #if RUBY_API_VERSION_MAJOR < 4 // Exists but not declared as public API @@ -20,13 +24,28 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; -typedef struct { - VALUE *context; - VALUE user_data; - VALUE callback; - VALUE callbacks; - bool is_interrupted; -} ruby_whisper_abort_callback_container; +typedef struct ruby_whisper_abort_callback_user_data { + volatile rb_atomic_t is_interrupted; + ruby_whisper_callback_container *callback_container; +} ruby_whisper_abort_callback_user_data; + +typedef struct ruby_whisper_log { + enum ggml_log_level level; + char *text; + size_t length; + size_t capacity; +} ruby_whisper_log; + +typedef struct ruby_whisper_log_queue { + rb_nativethread_lock_t lock; + rb_nativethread_cond_t cond; + bool is_open; + + size_t head; + size_t tail; + size_t size; + ruby_whisper_log *logs; +} ruby_whisper_log_queue; typedef struct { struct whisper_context *context; @@ -42,7 +61,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; @@ -84,6 +103,63 @@ typedef struct parsed_samples_t { bool memview_exported; } parsed_samples_t; +typedef struct { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} ruby_whisper_full_args; + +typedef struct ruby_whisper_full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} ruby_whisper_full_parallel_args; + +typedef struct { + struct parakeet_full_params params; + ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *new_token_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; + ruby_whisper_callback_container *abort_callback_container; +} ruby_whisper_parakeet_params; + +typedef struct { + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_params; + +typedef struct { + struct parakeet_context *context; +} ruby_whisper_parakeet_context; + +typedef struct { + VALUE context; + int index; +} ruby_whisper_parakeet_segment; + +typedef struct { + parakeet_token_data *token_data; + VALUE text; +} ruby_whisper_parakeet_token; + +typedef struct { + VALUE context; +} ruby_whisper_parakeet_model; + +extern ID id_extended; +extern ID id_log_callback_thread; +extern ID id_start_log_callback_thread; +extern ID id_alive_p; +extern ID id_join; +extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text); +extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue); + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -120,4 +196,47 @@ typedef struct parsed_samples_t { } \ } while (0) +#define GetParakeetContextParams(obj, rwpcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \ +} while (0) + +#define GetParakeetContext(obj, rwpc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ + if ((rwpc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetParams(obj, rwpp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \ + if (!(rwpp)->new_segment_callback_container || \ + !(rwpp)->new_token_callback_container || \ + !(rwpp)->progress_callback_container || \ + !(rwpp)->encoder_begin_callback_container || \ + !(rwpp)->abort_callback_container) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetSegment(obj, rwps) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \ + if (!(rwps)->context) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetToken(obj, rwpt) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \ + if (!(rwpt)->token_data) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetModel(obj, rwpm) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \ + if (NIL_P((rwpm)->context)) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #endif diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 26058fc07..9e5fc33e7 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data); ID transcribe_option_names[1]; @@ -38,21 +38,6 @@ typedef struct fill_samples_args { int n_samples; } fill_samples_args; -typedef struct full_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; -} full_args; - -typedef struct full_parallel_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; - int n_processors; -} full_parallel_args; - typedef struct full_without_gvl_args { struct whisper_context *context; struct whisper_full_params *params; @@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args { } full_parallel_without_gvl_args; typedef struct full_ubf_args { - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_abort_callback_user_data *abort_callback_user_data; } full_ubf_args; static void @@ -379,7 +364,7 @@ fill_samples(VALUE rb_args) return Qnil; } -struct parsed_samples_t +parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples) { bool memview_available = rb_memory_view_available_p(*samples); @@ -480,20 +465,24 @@ full_ubf(void *rb_args) { full_ubf_args *args = (full_ubf_args *)rb_args; - args->abort_callback_container->is_interrupted = true; + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); } -static VALUE +VALUE full_body(VALUE rb_args) { - full_args *args = (full_args *)rb_args; + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, 1); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, 1, &abort_callback_user_data); struct full_without_gvl_args full_without_gvl_args = { rw->context, @@ -503,7 +492,7 @@ full_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_without_gvl_args.result); @@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) VALUE n_samples = argc == 2 ? Qnil : argv[2]; struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - full_args args = { + ruby_whisper_full_args args = { &self, &argv[0], parsed.samples, @@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args) return NULL; } -static VALUE +VALUE full_parallel_body(VALUE rb_args) { - full_parallel_args *args = (full_parallel_args *)rb_args; + ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, args->n_processors); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data); struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { rw->context, @@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_parallel_without_gvl_args.result); @@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) break; } struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - const full_parallel_args args = { + const ruby_whisper_full_parallel_args args = { &self, &argv[0], parsed.samples, diff --git a/bindings/ruby/ext/ruby_whisper_log_queue.c b/bindings/ruby/ext/ruby_whisper_log_queue.c new file mode 100644 index 000000000..6558a339c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_queue.c @@ -0,0 +1,180 @@ +#include "ruby_whisper.h" + +#define LOG_QUEUE_CAPACITY 256 +#define LOG_DEFAULT_CAPACITY 1024 + +void +ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_initialize(&log_queue->lock); + rb_native_cond_initialize(&log_queue->cond); + log_queue->head = 0; + log_queue->tail = 0; + log_queue->size = 0; + log_queue->is_open = true; + log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY); + for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) { + // we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL + // this doesn't be freed because log queue lives until the end of process + char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY); + if (!slot) { + rb_raise(rb_eRuntimeError, "Could not allocate memory for log text"); + } + ruby_whisper_log log = { + 0, + slot, + 0, + LOG_QUEUE_CAPACITY, + }; + log_queue->logs[i] = log; + } +} + +void +ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = true; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +void +ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = false; + rb_native_cond_broadcast(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static size_t +calc_enough_cap(size_t len) +{ + size_t quot = len / LOG_DEFAULT_CAPACITY; + size_t rem = len % LOG_DEFAULT_CAPACITY; + + return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY; +} + +void +ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + if (!log_queue->is_open) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + + size_t len = strlen(text); + ruby_whisper_log *log = &log_queue->logs[log_queue->head]; + if (len > log->capacity) { + size_t new_cap = calc_enough_cap(len); + // we cannot call Ruby API like REALLOC_N because this function is called without GVL + char *slot = realloc(log->text, new_cap); + if (!slot) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + log->text = slot; + log->capacity = new_cap; + } + // we cannot call Ruby API like MEMCPY because this function is called without GVL + memcpy(log->text, text, sizeof(char) * len); + log->length = len; + log->level = level; + log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY; + bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY; + log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1; + if (is_full) { + log_queue->tail = log_queue->head; + } + + rb_native_cond_signal(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static void* +ruby_whisper_log_queue_wait(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_wait(&log_queue->cond, &log_queue->lock); + rb_nativethread_lock_unlock(&log_queue->lock); + + return NULL; +} + +static void +ruby_whisper_log_queue_wait_ubf(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_broadcast(&log_queue->cond); +} + +typedef struct { + enum ggml_log_level level; + size_t length; + char *text; +} log_snapshot; + +VALUE +ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue) +{ + log_snapshot logs[LOG_QUEUE_CAPACITY]; + + rb_nativethread_lock_lock(&log_queue->lock); + + while (log_queue->size == 0 && log_queue->is_open) { + rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue); + rb_nativethread_lock_lock(&log_queue->lock); + } + + if (log_queue->size == 0 && !log_queue->is_open) { + rb_native_cond_broadcast(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); + return Qnil; + } + + size_t size = log_queue->size; + ruby_whisper_log *log; + size_t i; + for (i = 0; i < size; i++) { + log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY]; + logs[i].level = log->level; + logs[i].length = log->length; + char *text = malloc(log->length); + if (!text) { + logs[i].text = NULL; + continue; + } + logs[i].text = text; + memcpy(logs[i].text, log->text, log->length); + } + log_queue->size = 0; + log_queue->tail = log_queue->head; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); + + VALUE rb_logs = rb_ary_new2(size); + VALUE rb_text; + for (i = 0; i < size; i++) { + if (!logs[i].text) { + continue; + } + rb_text = rb_str_new(logs[i].text, logs[i].length); + free(logs[i].text); + rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text)); + } + + return rb_logs; +} diff --git a/bindings/ruby/ext/ruby_whisper_log_settable.h b/bindings/ruby/ext/ruby_whisper_log_settable.h new file mode 100644 index 000000000..b98fbac82 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_settable.h @@ -0,0 +1,47 @@ +#ifndef RUBY_WHISPER_LOG_SETTABLE_H +#define RUBY_WHISPER_LOG_SETTABLE_H + +#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \ + static VALUE \ + ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \ + { \ + return ruby_whisper_log_queue_drain(&log_queue); \ + } \ + static void \ + ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \ + { \ + ruby_whisper_log_queue_enqueue(&log_queue, level, text); \ + } \ + static VALUE \ + ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \ + { \ + rb_iv_set(self, "@log_callback", log_callback); \ + rb_iv_set(self, "@log_callback_user_data", user_data); \ + if (NIL_P(log_callback)) { \ + log_set(NULL, NULL); \ + } else { \ + ruby_whisper_log_queue_open(&log_queue); \ + rb_funcall((mod), id_start_log_callback_thread, 0); \ + log_set(ruby_whisper_##log_queue##_log_callback, NULL); \ + } \ + return Qnil; \ + } \ + static void \ + ruby_whisper_##log_queue##_end_proc(VALUE args) \ + { \ + ruby_whisper_log_queue_close(&log_queue); \ + VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \ + if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \ + rb_funcall(log_callback_thread, id_join, 0); \ + } \ + } + +#define LOG_SETTABLE_INIT(log_queue, mod) \ + ruby_whisper_log_queue_initialize(&log_queue); \ + rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \ + rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \ + rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \ + rb_extend_object(mod, mLogSettable); \ + rb_funcall(mLogSettable, id_extended, 1, mod); + +#endif diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c new file mode 100644 index 000000000..d69369401 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -0,0 +1,49 @@ +#include "ruby_whisper.h" +#include +#include + +extern VALUE mParakeet; +extern VALUE mLogSettable; +extern VALUE cParakeetContext; +extern VALUE cParakeetSegment; +extern VALUE mOutputContext; +extern VALUE mOutputSegment; + +extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); +extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext); +extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); + +static ruby_whisper_log_queue parakeet_log_queue; + +LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set) + +static VALUE +ruby_whisper_parakeet_s_system_info_str(VALUE self) +{ + return rb_str_new2(parakeet_print_system_info()); +} + +void +init_ruby_whisper_parakeet(VALUE *mWhisper) +{ + mParakeet = rb_define_module_under(*mWhisper, "Parakeet"); + + rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version())); + + LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet) + + rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0); + + init_ruby_whisper_parakeet_params(&mParakeet); + init_ruby_whisper_parakeet_token(&mParakeet); + init_ruby_whisper_parakeet_segment(&mParakeet); + cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_context_params(&cParakeetContext); + init_ruby_whisper_parakeet_model(&mParakeet); + + rb_include_module(cParakeetContext, mOutputContext); + rb_include_module(cParakeetSegment, mOutputSegment); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c new file mode 100644 index 000000000..b4a2fc5c4 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -0,0 +1,304 @@ +#include "ruby_whisper.h" + +#define ITERATE_SEGMENT_ATTRS(ITERATOR) \ + ITERATOR(get_segment_t0, LONG) \ + ITERATOR(get_segment_t1, LONG) \ + ITERATOR(get_segment_text, STRING) \ + ITERATOR(n_tokens, INT) + +#define ITERATE_TOKEN_ATTRS(ITERATOR) \ + ITERATOR(get_token_text, STRING) \ + ITERATOR(get_token_id, INT) \ + ITERATOR(get_token_p, FLOAT) + +#define VAL_FROM_LONG(v) LONG2NUM(v) +#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v) +#define VAL_FROM_INT(v) INT2NUM(v) +#define VAL_FROM_FLOAT(v) DBL2NUM(v) +#define READER(type) VAL_FROM_##type + +extern ID id_to_s; +extern ID id___method__; +extern ID id_to_enum; +extern ID id_new; + +extern VALUE cParakeetContext; +extern VALUE eError; + +extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE rb_parsed_args); +extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); +extern rb_data_type_t ruby_whisper_parakeet_params_type; +extern rb_data_type_t ruby_whisper_parakeet_context_params_type; +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); + +static void +ruby_whisper_parakeet_context_free(void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (rwpc->context) { + parakeet_free(rwpc->context); + rwpc->context = NULL; + } + xfree(rwpc); +} + +static size_t +ruby_whisper_parakeet_context_memsize(const void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (!rwpc) { + return 0; + } + size_t size = sizeof(*rwpc); + return size; +} + +const rb_data_type_t ruby_whisper_parakeet_context_type = { + "ruby_whisper_parakeet_context", + {0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_context_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context *rwpc; + + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + rwpc->context = NULL; + + return obj; +} + +typedef struct { + struct parakeet_context **context; + char *model_path; + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_init_args; + +static void* +ruby_whisper_parakeet_context_init_without_gvl(void *args) +{ + ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args; + *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params); + return NULL; +} + +static VALUE +ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + VALUE model_path; + VALUE context_params; + struct parakeet_context_params params; + + rb_scan_args(argc, argv, "11", &model_path, &context_params); + TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + + model_path = ruby_whisper_normalize_model_path(model_path); + if (!rb_respond_to(model_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context"); + } + if (NIL_P(context_params)) { + params = parakeet_context_default_params(); + } else { + ruby_whisper_parakeet_context_params *rwpcp; + GetParakeetContextParams(context_params, rwpcp); + params = rwpcp->params; + } + ruby_whisper_parakeet_context_init_args init_args = { + &rwpc->context, + StringValueCStr(model_path), + params, + }; + rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL); + if (rwpc->context == NULL) { + rb_raise(rb_eRuntimeError, "Failed to load model"); + } + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_context_full_n_segments(VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + return INT2NUM(parakeet_full_n_segments(rwpc->context)); +} + +#define DEF_SEGMENT_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \ + } + +ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR) + +#define DEF_TOKEN_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \ + } + +ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR) + +static VALUE +ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token)); + + return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data); +} + +static VALUE +ruby_whisper_parakeet_context_each_segment(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + const int n_segments = parakeet_full_n_segments(rwpc->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(ruby_whisper_parakeet_segment_init(self, i)); + } + + return self; +} + +typedef struct { + struct parakeet_context *context; + struct parakeet_full_params *params; + float *samples; + int n_samples; + int result; +} parakeet_full_without_gvl_args; + +static void* +parakeet_full_without_gvl(void *rb_args) +{ + parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args; + args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_user_data *abort_callback_user_data; +} parakeet_full_ubf_args; + +static void +parakeet_full_ubf(void *rb_args) +{ + parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + +VALUE +ruby_whisper_parakeet_context_full_body(VALUE rb_args) +{ + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(*args->context, rwpc); + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(*args->params, rwpp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data); + + parakeet_full_without_gvl_args full_without_gvl_args = { + rwpc->context, + &rwpp->params, + args->samples, + args->n_samples, + 0 + }; + parakeet_full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args); + + return INT2NUM(full_without_gvl_args.result); +} + +static VALUE +ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + ruby_whisper_full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +static VALUE +ruby_whisper_parakeet_context_get_model(VALUE self) +{ + return ruby_whisper_parakeet_model_s_new(self); +} + +VALUE +init_ruby_whisper_parakeet_context(VALUE *mParakeet) +{ + cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject); + + rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate); + + rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1); + rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); + rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0); + rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2); + rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0); + rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0); + rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1); + +#define REGISTER_SEGMENT_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1); + + ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR) + +#define REGISTER_TOKEN_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2); + + ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR) + + return cParakeetContext; +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c new file mode 100644 index 000000000..38bd6d57c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c @@ -0,0 +1,117 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(use_gpu, BOOL) \ + ITERATOR(gpu_device, INT) + +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_INT(v) (NUM2INT(v)) +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type + +#define DEF_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + return READER(type)(rwpcp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + rwpcp->params.name = WRITER(type)(val); \ + return val; \ + } + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS +}; + +extern VALUE cParakeetContextParams; + +typedef VALUE (*param_writer_t)(VALUE, VALUE); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; + +static size_t +ruby_whisper_parakeet_context_params_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_context_params); +} + +const rb_data_type_t ruby_whisper_parakeet_context_params_type = { + "ruby_whisper_parakeet_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context_params *rwpcp; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); +} + +static VALUE +ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_context_params *rwpcp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); + rwpcp->params = parakeet_context_default_params(); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values); + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext) +{ + cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject); + + rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate); + + rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1); + + int i = 0; +#define REGISTER_ATTR(name, type) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \ + rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \ + rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \ + i++; + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_model.c b/bindings/ruby/ext/ruby_whisper_parakeet_model.c new file mode 100644 index 000000000..dce43c688 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_model.c @@ -0,0 +1,84 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(n_vocab) \ + ITERATOR(n_audio_ctx) \ + ITERATOR(n_audio_state) \ + ITERATOR(n_audio_head) \ + ITERATOR(n_audio_layer) \ + ITERATOR(n_mels) \ + ITERATOR(ftype) + +extern rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE cParakeetModel; + +static void +ruby_whisper_parakeet_model_mark(void *p) +{ + ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p; + if (!NIL_P(rwpm->context)) { + rb_gc_mark(rwpm->context); + } +} + +static size_t +ruby_whisper_parakeet_model_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_model); +} + +static const rb_data_type_t ruby_whisper_parakeet_model_type = { + "ruby_whisper_parakeet_model", + {ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_model_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_model *rwpm; + VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = Qnil; + + return model; +} + +VALUE +ruby_whisper_parakeet_model_s_new(VALUE context) +{ + const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel); + ruby_whisper_parakeet_model *rwpm; + TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = context; + return model; +} + +#define DEF_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_model_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_model *rwpm; \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetModel(self, rwpm); \ + GetParakeetContext(rwpm->context, rwpc); \ + return INT2NUM(parakeet_model_##name(rwpc->context)); \ + } + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_model(VALUE *mParakeet) +{ + cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject); + + rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate); + +#define REGISTER_ATTR(name) \ + rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c new file mode 100644 index 000000000..076e2a0cd --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -0,0 +1,548 @@ +#include "ruby_whisper.h" + +#define ITERATE_PARAMS(ITERATOR) \ + ITERATOR(n_threads, INT) \ + ITERATOR(offset_ms, INT) \ + ITERATOR(duration_ms, INT) \ + ITERATOR(no_context, BOOL) \ + ITERATOR(audio_ctx, INT) + +#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \ + ITERATOR(new_segment, DATA) \ + ITERATOR(new_token, DATA) \ + ITERATOR(progress, DATA) \ + ITERATOR(encoder_begin, DATA) + +#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback) +#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR) + +#define ITERATE_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATOR(abort_callback) + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data, + ITERATE_PARAMS(DEF_IDX) + ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK) + ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA) + + RUBY_WHISPER_PARAKEET_NUM_PARAMS +}; + +#define VAL_TO_INT(v) (NUM2INT(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse) + +extern VALUE cParakeetParams; +extern ID id_call; + +extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc); +extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); +extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; +typedef VALUE (*param_writer_t)(VALUE, VALUE); +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int n_new; +} call_parakeet_new_segment_callbacks_args; + +static void* +call_parakeet_new_segment_callbacks(void *v_args) +{ + call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + const int n_segments = parakeet_full_n_segments_from_state(args->state); + for (int i = args->n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment); + for (int j = 0; j < n_callbacks; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_segment_callbacks_args args = { + container, + state, + n_new, + }; + rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_context *context; + struct parakeet_state *state; + const parakeet_token_data *token_data; +} call_parakeet_new_token_callbacks_args; + +static void* +call_parakeet_new_token_callbacks(void *v_args) +{ + call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args; + VALUE token = Qnil; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + if (NIL_P(token)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + } + for (int i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, token); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_token_callbacks_args args = { + container, + context, + state, + token_data, + }; + rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int progress; +} call_parakeet_progress_callbacks_args; + +static void* +call_parakeet_progress_callback(void *v_args) +{ + call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, INT2NUM(args->progress)); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_progress_callbacks_args args = { + container, + state, + progress, + }; + rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + bool is_continued; +} call_parakeet_encoder_begin_callbacks_args; + +static void* +call_parakeet_encoder_begin_callbacks(void *v_args) +{ + call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_parakeet_encoder_begin_callbacks_args args = { + container, + state, + true, + }; + rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_parakeet_abort_callbacks_args; + +static void* +call_parakeet_abort_callbacks(void *v_args) +{ + call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + VALUE cb; + for (long i = 0; i < n_callbacks; i++) { + cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_abort_callback(void *user_data) +{ + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } + + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_parakeet_abort_callbacks_args args = { + data->callback_container, + false, + }; + rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args); + + return args.is_interrupted; +} + +#define CALLBACK_CONTAINER_NAME(name) name ## _container + +void +ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) +{ +#define PARAM_NAME(name) name +#define USER_DATA_NAME(name) name##_user_data +#define REGISTER_CALLBACK(name) \ + if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \ + rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \ + rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \ + rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \ + } + + ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK) + + if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) { + abort_callback_user_data->callback_container = rwpp->abort_callback_container; + } + rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback; + rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data; +} + +static void +ruby_whisper_parakeet_params_mark(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define MARK_CONTAINER(name) \ + if (rwpp->name##_container) { \ + ruby_whisper_callback_container_mark(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(MARK_CONTAINER) +} + +static void +ruby_whisper_parakeet_params_free(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define FREE_CONTAINER(name) \ + if (rwpp->name##_container) { \ + xfree(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(FREE_CONTAINER) + + xfree(rwpp); +} + +static size_t +ruby_whisper_parakeet_params_memsize(const void *p) +{ + const struct ruby_whisper_parakeet_params *params = p; + if (!params) { + return 0; + } + return sizeof(ruby_whisper_parakeet_params); +} + +const rb_data_type_t ruby_whisper_parakeet_params_type = { + "ruby_whisper_parakeet_params", + {ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,}, + 0, 0, + 0 +}; + +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type +#define DEF_PARAM_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return READER(type)(rwpp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->params.name = WRITER(type)(val); \ + return val; \ + } + +#define DEF_CALLBACK_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \ + return val; \ + } + +#define DEF_USER_DATA_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \ + return val; \ + } + +#define DEF_HOOK(name, data) \ + static VALUE \ + ruby_whisper_parakeet_params_on_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + const VALUE blk = rb_block_proc(); \ + if (NIL_P(rwpp->name##_callback_container->callbacks)) { \ + rwpp->name##_callback_container->callbacks = rb_ary_new(); \ + } \ + rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \ + return Qnil; \ + } + +ITERATE_PARAMS(DEF_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR) +ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _) + +static VALUE +ruby_whisper_parakeet_params_abort_on(VALUE self) +{ + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(self, rwpp); + const VALUE blk = rb_block_proc(); + if (NIL_P(rwpp->abort_callback_container->callbacks)) { + rwpp->abort_callback_container->callbacks = rb_ary_new(); + } + rb_ary_push(rwpp->abort_callback_container->callbacks, blk); + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_params *rwpp; + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + return obj; +} + +static VALUE +ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_params *rwpp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + +#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate(); + + ITERATE_CALLBACK_PARAMS(INIT_CONTAINER) + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values); + + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +void +init_ruby_whisper_parakeet_params(VALUE *mParakeet) +{ + cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject); + rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate); + + rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1); + + int i = 0; +#define REGISTER_PARAM(name) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_params_set_##name; \ + rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \ + rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \ + i++; + +#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name) +#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name) +#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data) + + ITERATE_PARAMS(REGISTER_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR) + +#define REGISTER_HOOK(name, data) \ + rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0); + + ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _) + + rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c new file mode 100644 index 000000000..b1e81ba93 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c @@ -0,0 +1,157 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(start_time, t0, TIME) \ + ITERATOR(end_time, t1, TIME) \ + ITERATOR(text, text, STRING) + +enum { +#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, +}; + +#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10)) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) +#define READER(type) VAL_FROM_##type +#define DEF_ATTR(rb_name, c_name, type) \ + static VALUE \ + ruby_whisper_parakeet_get_##rb_name(VALUE self) \ + { \ + ruby_whisper_parakeet_segment *rwps; \ + GetParakeetSegment(self, rwps); \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(rwps->context, rwpc); \ + return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \ + } + +extern ID id___method__; +extern ID id_to_enum; +extern VALUE cParakeetSegment; +extern VALUE sym_start_time; +extern VALUE sym_end_time; +extern VALUE sym_text; +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token); + +static void +rb_whisper_parakeet_segment_mark(void *p) +{ + ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p; + rb_gc_mark(rwps->context); +} + +static size_t +ruby_whisper_parakeet_segment_memsize(const void *p) +{ + const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p; + if (!rwps) { + return 0; + } + return sizeof(*rwps); +} + +static const rb_data_type_t ruby_whisper_parakeet_segment_type = { + "ruby_whisper_parakeet_segment", + {rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_segment_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_segment *rwps; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); +} + +VALUE +ruby_whisper_parakeet_segment_init(VALUE context, int index) +{ + ruby_whisper_parakeet_segment *rwps; + + const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment); + TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); + rwps->context = context; + rwps->index = index; + + return segment; +} + +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_segment_each_token(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index); + for (int i = 0; i < n_tokens; i++) { + rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i)); + } + + return self; +} + +static VALUE +ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + VALUE hash = rb_hash_new(); + long n_keys; + if (NIL_P(keys)) { + keys = rb_ary_new3( + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, + sym_start_time, + sym_end_time, + sym_text + ); + n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) { + return hash; + } + } + for (int i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, c_name, type) \ + if (key == sym_##rb_name) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \ + } + + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_segment(VALUE *mParakeet) +{ + cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject); + + rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate); + +#define REGISTER_ATTR(rb_name, c_name, type) \ + rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0); + rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_token.c b/bindings/ruby/ext/ruby_whisper_parakeet_token.c new file mode 100644 index 000000000..a00b7ae1c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_token.c @@ -0,0 +1,188 @@ +#include "ruby_whisper.h" + +#define ITERATE_MEMBERS(ITERATOR) \ + ITERATOR(id, id, id, id, INT) \ + ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \ + ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \ + ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \ + ITERATOR(probability, probability, p, p, FLOAT) \ + ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \ + ITERATOR(start_time, start_time, start_time, t0, TIME) \ + ITERATOR(end_time, end_time, end_time, t1, TIME) \ + ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL) + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(text, text, text, text, STRING) + +enum { +#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name, + + ITERATE_MEMBERS(DEF_IDX) + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, +}; + +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_FROM_FLOAT(v) (DBL2NUM(v)) +#define VAL_FROM_TIME(v) (LONG2NUM(v * 10)) +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) + +#define READER(type) VAL_FROM_##type +#define MEMBER_NAME(name) name +#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \ + } + +#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return rwpt->p_name; \ + } + +VALUE cParakeetToken; + +#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key; + +ITERATE_MEMBERS(DEC_ATTR_SYMS) +ITERATE_ATTRS(DEC_ATTR_SYMS) + +static void +ruby_whisper_parakeet_token_mark(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + rb_gc_mark(rwpt->text); +} + +static void +ruby_whisper_parakeet_token_free(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (rwpt->token_data) { + xfree(rwpt->token_data); + rwpt->token_data = NULL; + } + xfree(rwpt); +} + +static size_t +ruby_whisper_parakeet_token_memsize(const void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (!rwpt) { + return 0; + } + size_t size = sizeof(*rwpt); + if (rwpt->token_data) { + size += sizeof(*rwpt->token_data); + } + + return size; +} + +static const rb_data_type_t ruby_whisper_parakeet_token_type = { + "ruby_whisper_parakeet_token", + {ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_token_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_token *rwpt; + VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = NULL; + rwpt->text = Qnil; + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data) +{ + const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken); + ruby_whisper_parakeet_token *rwpt; + TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = ALLOC(parakeet_token_data); + *rwpt->token_data = *token_data; + rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id)); + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token) +{ + parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token); + return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data); +} + +ITERATE_MEMBERS(DEF_MEMBER_ATTR) +// Define #text using parakeet_token_to_str or parakeet_token_to_text +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_token *rwpt; + GetParakeetToken(self, rwpt); + + VALUE hash = rb_hash_new(); + long n_keys = 0; + + if (NIL_P(keys)) { + VALUE attrs[] = { +#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key, + + ITERATE_MEMBERS(LIST_SYMS) + ITERATE_ATTRS(LIST_SYMS) + }; + keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs); + n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) { + return hash; + } + } + for (long i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \ + if (key == sym_##s_key) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \ + } + + ITERATE_MEMBERS(CHECK_AND_SET_KEY) + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_token(VALUE *mParakeet) +{ + cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject); + rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate); + +#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \ + sym_##s_key = ID2SYM(rb_intern(#s_key)); \ + rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0); + + ITERATE_MEMBERS(REGISTER_ATTR) + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp new file mode 100644 index 000000000..c4deccce8 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -0,0 +1,58 @@ +#include "ruby_whisper.h" +#include "common-whisper.h" +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern const rb_data_type_t ruby_whisper_parakeet_params_type; + +extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args); + +extern ID id_to_path; +extern ID id_new; + +extern VALUE eError; + +VALUE +ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) +{ + if (rb_respond_to(audio_path, id_to_path)) { + audio_path = rb_funcall(audio_path, id_to_path, 0); + } + + std::string fname = StringValueCStr(audio_path); + std::vector pcmf32; + std::vector> pcmf32s; + + if (!read_audio_data(fname, pcmf32, pcmf32s, false)) { + rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str()); + return Qnil; + } + + ruby_whisper_parakeet_context *rwpc; + ruby_whisper_parakeet_params *rwpp; + GetParakeetContext(self, rwpc); + GetParakeetParams(params, rwpp); + + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 2aae7c12d..f38e9bde3 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -76,8 +76,8 @@ static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; -static void -rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +void +ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc) { if (rwc == NULL) return; @@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_gc_mark(rwc->callbacks); } -static ruby_whisper_callback_container* -rb_whisper_callback_container_allocate() { +ruby_whisper_callback_container* +ruby_whisper_callback_container_allocate() { ruby_whisper_callback_container *container; container = ALLOC(ruby_whisper_callback_container); container->context = NULL; @@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() { return container; } -static void -rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) -{ - if (rwc == NULL) return; - - rb_gc_mark(rwc->user_data); - rb_gc_mark(rwc->callback); - rb_gc_mark(rwc->callbacks); -} - -static ruby_whisper_abort_callback_container* -rb_whisper_abort_callback_container_allocate() { - ruby_whisper_abort_callback_container *container; - container = ALLOC(ruby_whisper_abort_callback_container); - container->context = NULL; - container->user_data = Qnil; - container->callback = Qnil; - container->callbacks = Qnil; - container->is_interrupted = false; - return container; -} - -static bool +bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { return !NIL_P(container->callback) || !NIL_P(container->callbacks); } -static bool -ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { - return !NIL_P(container->callback) || !NIL_P(container->callbacks); -} - typedef struct { const ruby_whisper_callback_container *container; struct whisper_state *state; @@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s } typedef struct { - const ruby_whisper_abort_callback_container *container; - struct whisper_state *state; + const ruby_whisper_callback_container *container; bool is_interrupted; } call_abort_callbacks_args; static void* call_abort_callbacks(void *v_args) { call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; - const ruby_whisper_abort_callback_container *container = args->container; - - if (container->is_interrupted) { - args->is_interrupted = true; - return NULL; - } + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) { if (NIL_P(container->callbacks)) { return NULL; } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (0 == n_callbacks) { return NULL; } - for (int j = 0; j < callbacks_len; j++) { + for (int j = 0; j < n_callbacks; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + VALUE result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) { } static bool abort_callback(void * user_data) { - const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; - if (container->is_interrupted) { + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { return true; } - if (!ruby_whisper_abort_callback_container_is_present(container)) { + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { return false; } call_abort_callbacks_args args = { - container, - NULL, + data->callback_container, false }; rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); @@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) return; } - if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { - rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); - } - - if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { - rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); - } + // new_segment_callback is called only after multiple threads are joined + // progress_callback is not called when parallel if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { + if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } - - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (!NIL_P(log_callback)) { - rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); - } } -static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; @@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } + abort_callback_user_data->callback_container = rwp->abort_callback_container; rwp->abort_callback_container->context = context; rwp->params.abort_callback = abort_callback; - rwp->abort_callback_container->is_interrupted = false; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; + rwp->params.abort_callback_user_data = (void *)abort_callback_user_data; } static void set_vad_params(ruby_whisper_params *rwp) @@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } -/* - TODO: Set abort callback to trap SIGINT and SIGTERM -*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { check_thread_safety(rwp, n_processors); - register_callbacks(rwp, context); + register_callbacks(rwp, context, abort_callback_user_data); set_vad_params(rwp); } @@ -421,10 +376,10 @@ void rb_whisper_params_mark(void *p) { ruby_whisper_params *rwp = (ruby_whisper_params *)p; - rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); - rb_whisper_callbcack_container_mark(rwp->progress_callback_container); - rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); + ruby_whisper_callback_container_mark(rwp->new_segment_callback_container); + ruby_whisper_callback_container_mark(rwp->progress_callback_container); + ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container); + ruby_whisper_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass) } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); + rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); + rwp->abort_callback_container = ruby_whisper_callback_container_allocate(); return obj; } diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index ee0d66c4c..cf0372797 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -4,12 +4,12 @@ extern ID id___method__; extern ID id_to_enum; -static VALUE sym_start_time; -static VALUE sym_end_time; -static VALUE sym_text; -static VALUE sym_no_speech_prob; -static VALUE sym_speaker_turn_next; -static VALUE sym_n_tokens; +VALUE sym_start_time; +VALUE sym_end_time; +VALUE sym_text; +VALUE sym_no_speech_prob; +VALUE sym_speaker_turn_next; +VALUE sym_n_tokens; extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 37656af1c..73f606ca4 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,6 +16,8 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern VALUE full_body(VALUE rb_args); +extern VALUE full_parallel_body(VALUE rb_args); typedef struct{ struct whisper_context *context; @@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args) return NULL; } -typedef struct { - ruby_whisper_abort_callback_container *abort_callback_container; -} transcribe_ubf_args; - -static void -transcribe_ubf(void *rb_args) -{ - transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; - - args->abort_callback_container->is_interrupted = true; -} - /* * transcribe a single file * can emit to a block results @@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - // Commented out because it is work in progress - // { - // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - // bool is_aborted = *(bool*)user_data; - // return !is_aborted; - // }; - // rwp->params.encoder_begin_callback_user_data = &is_aborted; - // } - - prepare_transcription(rwp, &self, n_processors); - - transcribe_without_gvl_args args = { - rw->context, - &rwp->params, - pcmf32.data(), - pcmf32.size(), - n_processors, - 0, - }; - transcribe_ubf_args ubf_args = { - rwp->abort_callback_container, - }; - rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); - if (args.result != 0) { + VALUE rb_result; + if (n_processors == 1) { + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + rb_result = full_body((VALUE)&args); + } else { + ruby_whisper_full_parallel_args parallel_args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + n_processors, + }; + rb_result = full_parallel_body((VALUE)¶llel_args); + } + const int result = NUM2INT(rb_result); + if (result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/lib/whisper/context.rb b/bindings/ruby/lib/whisper/context.rb deleted file mode 100644 index c3a134b77..000000000 --- a/bindings/ruby/lib/whisper/context.rb +++ /dev/null @@ -1,15 +0,0 @@ -module Whisper - class Context - def to_srt - each_segment.with_index.reduce("") {|srt, (segment, index)| - srt << "#{index + 1}\n#{segment.to_srt_cue}\n" - } - end - - def to_webvtt - each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| - webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" - } - end - end -end diff --git a/bindings/ruby/lib/whisper/log_settable.rb b/bindings/ruby/lib/whisper/log_settable.rb new file mode 100644 index 000000000..2f8218d26 --- /dev/null +++ b/bindings/ruby/lib/whisper/log_settable.rb @@ -0,0 +1,36 @@ +require "mutex_m" + +module Whisper + module LogSettable + class << self + def extended(base) + base.extend Mutex_m + end + end + + private + + def start_log_callback_thread + return if @log_callback_thread&.alive? + + @log_callback_thread = Thread.new { + begin + while logs = drain_logs + begin + callback, user_data = synchronize {[@log_callback, @log_callback_user_data]} + next if callback.nil? + + logs.each do |(level, text)| + callback.call level, text, user_data + end + rescue => err + $stderr.puts err + end + end + rescue => err + $stderr.puts err + end + } + end + end +end diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 8eb57e5e8..ef92eb901 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -41,6 +41,8 @@ module Whisper def cache path = cache_path + return path if cache_path.exist? + headers = {} headers["if-modified-since"] = path.mtime.httpdate if path.exist? request @uri, headers @@ -216,8 +218,18 @@ module Whisper @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin") end + %w[ + parakeet-tdt-0.6b-v3-f16 + parakeet-tdt-0.6b-v3-f32 + parakeet-tdt-0.6b-v3-q4_0 + parakeet-tdt-0.6b-v3-q4_k + parakeet-tdt-0.6b-v3-q8_0 + ].each do |name| + @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin") + end + @coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models| - next if name.end_with?("-tdrz") || name.start_with?("silero-") + next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-") if matched = name.match(/\A(?.*)-q\d_\d\z/) name = matched[:name] diff --git a/bindings/ruby/lib/whisper/output.rb b/bindings/ruby/lib/whisper/output.rb new file mode 100644 index 000000000..1781af17a --- /dev/null +++ b/bindings/ruby/lib/whisper/output.rb @@ -0,0 +1,74 @@ +module Whisper + module Output + module Context + def to_srt + each_segment.with_index.reduce("") {|srt, (segment, index)| + srt << "#{index + 1}\n#{segment.to_srt_cue}\n" + } + end + + def to_webvtt + each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| + webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" + } + end + end + + module Segment + SRT_ESCAPES = { + "&" => "&", + "<" => "<", + ">" => ">", + } + SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) + private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE + + def to_srt_cue + "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" + end + + def to_webvtt_cue + "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" + end + + private + + def time_to_a(time) + sec, decimal_part = time.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + [hour, min, sec, decimal_part] + end + + def srt_time(time) + "%02d:%02d:%02d,%03d" % time_to_a(time) + end + + def srt_start_time + srt_time(start_time) + end + + def srt_end_time + srt_time(end_time) + end + + def srt_text + text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) + end + + def webvtt_time(time) + "%02d:%02d:%02d.%03d" % time_to_a(time) + end + + def webvtt_start_time + webvtt_time(start_time) + end + + def webvtt_end_time + webvtt_time(end_time) + end + + alias webvtt_text srt_text + end + end +end diff --git a/bindings/ruby/lib/whisper/segment.rb b/bindings/ruby/lib/whisper/segment.rb deleted file mode 100644 index dc187dcac..000000000 --- a/bindings/ruby/lib/whisper/segment.rb +++ /dev/null @@ -1,58 +0,0 @@ -module Whisper - class Segment - SRT_ESCAPES = { - "&" => "&", - "<" => "<", - ">" => ">", - } - SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) - private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE - - def to_srt_cue - "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" - end - - def to_webvtt_cue - "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" - end - - private - - def time_to_a(time) - sec, decimal_part = time.divmod(1000) - min, sec = sec.divmod(60) - hour, min = min.divmod(60) - [hour, min, sec, decimal_part] - end - - def srt_time(time) - "%02d:%02d:%02d,%03d" % time_to_a(time) - end - - def srt_start_time - srt_time(start_time) - end - - def srt_end_time - srt_time(end_time) - end - - def srt_text - text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) - end - - def webvtt_time(time) - "%02d:%02d:%02d.%03d" % time_to_a(time) - end - - def webvtt_start_time - webvtt_time(start_time) - end - - def webvtt_end_time - webvtt_time(end_time) - end - - alias webvtt_text srt_text - end -end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index cbec48038..c12e1fe55 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -40,7 +40,21 @@ module Whisper def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String + module Output + module Context + def to_srt: () -> String + def to_webvtt: () -> String + end + + module Segment + def to_srt_cue: () -> String + def to_webvtt_cue: () -> String + end + end + class Context + include Output::Context + def self.new: (String | path | ::URI::HTTP) -> instance # transcribe a single file @@ -139,17 +153,14 @@ module Whisper | (Whisper::Params, _Samples, ?Integer n_samples) -> self | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self - def to_srt: () -> String - def to_webvtt: () -> String - class Params def self.new: ( - use_gpu: boolish, - flash_attn: boolish, - gpu_device: Integer, - dtw_token_timestamps: boolish, - dtw_aheads_preset: Integer, - dtw_n_top: Integer | nil, + ?use_gpu: boolish, + ?flash_attn: boolish, + ?gpu_device: Integer, + ?dtw_token_timestamps: boolish, + ?dtw_aheads_preset: Integer, + ?dtw_n_top: Integer | nil, ) -> instance def use_gpu=: (boolish) -> boolish @@ -444,6 +455,9 @@ module Whisper def abort_on: { (Object user_data) -> boolish } -> void end + module LogSettable + end + class Model def self.pre_converted_models: () -> Hash[String, Model::URI] def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI] @@ -474,6 +488,8 @@ module Whisper end class Segment + include Output::Segment + type deconstructed_keys = { start_time: (Integer | nil), end_time: (Integer | nil), @@ -514,9 +530,6 @@ module Whisper # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] - def to_srt_cue: () -> String - def to_webvtt_cue: () -> String - # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # @@ -528,7 +541,7 @@ module Whisper def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys end - module Token + class Token type deconstructed_keys = { id: (Integer | nil), tid: (Integer | nil), @@ -598,6 +611,336 @@ module Whisper def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys end + module Parakeet + extend LogSettable + + VERSION: String + + # Control logging output. The default behavior is to print to stderr. + # + def self.log_set: (nil, Object? user_data) -> nil + | (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil + def self.system_info_str: () -> String + + class Context + include Output::Context + + # Load a Parakeet model from the given file path. + # + def self.new: (String | path | ::URI::HTTP, ?Params) -> instance + + # Transcribe a single audio file. + # + def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self + + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + # Not thread safe for the same context. + # + # The second argument `samples` must be an array of samples, respond to `:length`, + # or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # + def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self + + # Number of generated text segments. + # + def full_n_segments: () -> Integer + + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t0(3) # => 1668 (16680 ms) + # + def full_get_segment_t0: (Integer segment_index) -> Integer + + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t1(3) # => 1668 (16680 ms) + # + def full_get_segment_t1: (Integer segment_index) -> Integer + + # Text of a segment indexed by `segment_index`. + # + # full_get_segment_text(3) # => "ask not what your country can do for you, ..." + # + def full_get_segment_text: (Integer segment_index) -> String + + # Number of tokens in the segment indexed by `segment_index`. + # + def full_n_tokens: (Integer segment_index) -> Integer + + # Text of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_text: (Integer segment_index, Integer token_index) -> String + + # Token id of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer + + # Probability of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_p: (Integer segment_index, Integer token_index) -> Float + + # Token data of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_data: (Integer segment_index, Integer token_index) -> Token + + def model: () -> Model + + # Yields each Whisper::Parakeet::Segment: + # + # parakeet.transcribe("path/to/audio.wav", params) + # parakeet.each_segment do |segment| + # puts segment.text + # end + # + # Returns an `Enumerator` if no block given: + # + # parakeet.transcribe("path/to/audio.wav", params) + # enum = parakeet.each_segment + # enum.to_a # => [#, ...] + # + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + + class Params + def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance + def use_gpu: () -> boolish + def use_gpu=: (boolish) -> boolish + def gpu_device: () -> Integer + def gpu_device=: (Integer) -> Integer + end + end + + class Params + def self.new: ( + ?n_threads: Integer, + ?offset_ms: Integer, + ?duration_ms: Integer, + ?no_context: boolish, + ?audio_ctx: Integer, + ?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void, + ?new_segment_callback_user_data: Object, + ?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void, + ?new_token_callback_user_data: Object, + ?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void, + ?progress_callback_user_data: Object, + ?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish, + ?encoder_begin_callback_user_data: Object, + ?abort_callback: ^(Object user_data) -> boolish, + ?abort_callback_user_data: Object + ) -> instance + + # Number of threads to use. + # + def n_threads=: (Integer) -> Integer + def n_threads: () -> Integer + + # Start offset in ms. + # + def offset_ms=: (Integer) -> Integer + def offset_ms: () -> Integer + + # Audio duration to process in ms. + # + def duration_ms=: (Integer) -> Integer + def duration_ms: () -> Integer + + # If `true`, does not use past transcription (if any) as context. + # + def no_context=: (boolish) -> boolish + def no_context: () -> (true | false) + + # Overwrite the audio context size. `0` uses the default value. + # + def audio_ctx=: (Integer) -> Integer + def audio_ctx: () -> Integer + + # Sets new segment callback, called for every newly generated text segment. + # + # params.new_segment_callback = ->(context, _, n_new, user_data) { + # # ... + # } + # + def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) + def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of new segment callback. + # + def new_segment_callback_user_data=: (Object?) -> Object? + def new_segment_callback_user_data: () -> Object? + + # Sets token callback, called for every newly predicted token. + # + def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) + def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of token callback. + # + def new_token_callback_user_data=: (Object?) -> Object? + def new_token_callback_user_data: () -> Object? + + # Sets progress callback, called on each progress update. + # + # +progress+ is an Integer between 0 and 100. + # + def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) + def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of progress callback. + # + def progress_callback_user_data=: (Object?) -> Object? + def progress_callback_user_data: () -> Object? + + # Sets encoder begin callback, called each time before the encoder starts. + # + # If it returns `false`, the computation is aborted. + # + def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) + def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of encoder begin callback. + # + def encoder_begin_callback_user_data=: (Object?) -> Object? + def encoder_begin_callback_user_data: () -> Object? + + # Sets abort callback, called each time before ggml computation starts. + # + def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish) + def abort_callback: () -> ((^(Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of abort callback. + # + def abort_callback_user_data=: (Object?) -> Object? + def abort_callback_user_data: () -> Object? + + # Hook called on new segment. Yields each Whisper::Parakeet::Segment. + # + def on_new_segment: { (Segment) -> void } -> void + + # Hook called on new token. Yields each Whisper::Parakeet::Token. + # + def on_new_token: { (Token) -> void } -> void + + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. + # + def on_progress: { (Integer progress) -> void } -> void + + # Hook called each time before the encoder starts. + # + def on_encoder_begin: { () -> boolish } -> void + + # Call block to determine whether abort or not. Return `true` when you want to abort. + # + def abort_on: { () -> boolish } -> void + end + + class Segment + include Output::Segment + + type deconstructed_keys = { + start_time: (Integer | nil), + end_time: (Integer | nil), + text: (String | nil) + } + + # Start time in milliseconds. + # + def start_time: () -> Integer + + # End time in milliseconds. + # + def end_time: () -> Integer + + # Text of the segment. + # + def text: () -> String + + # Yields each Whisper::Parakeet::Token: + # + # parakeet.each_segment.first.each_token do |token| + # p token + # end + # + # Returns an `Enumerator` if no block is given: + # + # parakeet.each_segment.first.each_token.to_a # => [#, ...] + # + def each_token: { (Token) -> void } -> void + | () -> Enumerator[Token] + + # Possible keys: `:start_time`, `:end_time`, `:text` + # + def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys + end + + class Token + type deconstructed_keys = { + id: (Integer | nil), + duration_idx: (Integer | nil), + duration_value: (Integer | nil), + frame_index: (Integer | nil), + probability: (Float | nil), + log_probability: (Float | nil), + start_time: (Integer | nil), + end_time: (Integer | nil), + word_start: ((true | false) | nil), + text: (String | nil), + } + + # Token ID. + # + def id: () -> Integer + + # Index into the model's durations array. + # + def duration_idx: () -> Integer + + # Actual duration value. + # + def duration_value: () -> Integer + + # Frame index of the token. + # + def frame_index: () -> Integer + + # Probability of the token. + # + def probability: () -> Float + + # Log probability of the token. + # + def log_probability: () -> Float + + # Start time of the token in milliseconds. + # + def start_time: () -> Integer + + # End time of the token in milliseconds. + # + def end_time: () -> Integer + + # Whether this token is the start of a word. + # + def word_start?: () -> (true | false) + + # Get the token text of the token. + # + def text: () -> String + + def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys + end + + class Model + def n_vocab: () -> Integer + def n_audio_ctx: () -> Integer + def n_audio_state: () -> Integer + def n_audio_head: () -> Integer + def n_audio_layer: () -> Integer + def n_mels: () -> Integer + def ftype: () -> Integer + end + end + module VAD class Params def self.new: ( diff --git a/bindings/ruby/test/helper.rb b/bindings/ruby/test/helper.rb index 56cd3849f..5e37ad985 100644 --- a/bindings/ruby/test/helper.rb +++ b/bindings/ruby/test/helper.rb @@ -5,6 +5,8 @@ require_relative "jfk_reader/jfk_reader" class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "fixtures", "jfk.wav") + Parakeet = Whisper::Parakeet + class << self def whisper return @whisper if @whisper diff --git a/bindings/ruby/test/test_callback.rb b/bindings/ruby/test/test_callback.rb index a7f49245a..6490c8abb 100644 --- a/bindings/ruby/test/test_callback.rb +++ b/bindings/ruby/test/test_callback.rb @@ -129,6 +129,7 @@ class TestCallback < TestBase return false } @whisper.transcribe(@audio, @params) + sleep 0.5 # wait for logs dequeued assert_match(/encoder_begin_callback returned false - aborting/, logs.join) Whisper.log_set ->(level, buffer, user_data) {}, nil end diff --git a/bindings/ruby/test/test_parakeet.rb b/bindings/ruby/test/test_parakeet.rb new file mode 100644 index 000000000..bfd57076f --- /dev/null +++ b/bindings/ruby/test/test_parakeet.rb @@ -0,0 +1,28 @@ +require_relative "helper" +require "stringio" + +class TestParakeet < TestBase + def test_log_set + log_callback = Parakeet.instance_variable_get("@log_callback") + user_data = Parakeet.instance_variable_get("@log_callback_user_data") + + $stdout = StringIO.new + Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil + Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + sleep 0.1 + $stdout.rewind + logs = $stdout.string + assert_match /loading model from/, logs + ensure + $stdout = STDOUT + Parakeet.log_set log_callback, user_data + end + + def test_system_info_str + assert_match /\APARAKEET : /, Parakeet.system_info_str + end + + def test_version + assert_instance_of String, Parakeet::VERSION + end +end diff --git a/bindings/ruby/test/test_parakeet_callback.rb b/bindings/ruby/test/test_parakeet_callback.rb new file mode 100644 index 000000000..1209e960f --- /dev/null +++ b/bindings/ruby/test/test_parakeet_callback.rb @@ -0,0 +1,107 @@ +require_relative "helper" + +class TestParakeetCallback < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + @params = Parakeet::Params.new + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @parakeet, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0 + end + } + + @parakeet.transcribe AUDIO, @params + end + + def test_on_new_segment + seg = nil + index = 0 + @params.on_new_segment do |segment| + assert_instance_of Parakeet::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text) + end + index += 1 + end + @parakeet.transcribe AUDIO, @params + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text + end + + def test_on_new_token + index = 0 + @params.on_new_token do |token| + assert_instance_of Parakeet::Token, token + if index == 0 + assert_instance_of Integer, token.start_time + assert_match "▁And", token.text + end + index += 1 + end + + @parakeet.transcribe AUDIO, @params + end + + def test_on_progress + first = nil + @params.on_progress do |progress| + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + first = progress if first.nil? + end + + @parakeet.transcribe AUDIO, @params + + assert_equal 0, first + end + + def test_on_encoder_begin + i = 0 + @params.on_encoder_begin do + i += 1 + end + + @parakeet.transcribe AUDIO, @params + + assert i > 0 + end + + def test_abort_on + do_abort = false + @params.on_new_segment do |segment| + do_abort = true if segment.text.match?(/ask/) + end + i = 0 + @params.abort_on do + i += 1 + do_abort + end + + @parakeet.transcribe(AUDIO, @params) rescue nil + + assert i > 0 + end +end diff --git a/bindings/ruby/test/test_parakeet_context.rb b/bindings/ruby/test/test_parakeet_context.rb new file mode 100644 index 000000000..2d039ce75 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context.rb @@ -0,0 +1,116 @@ +require_relative "helper" +require "stringio" + +class TestParakeetContext < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Context, @parakeet + end + + def test_new_with_params + log_callback = Parakeet.instance_variable_get(:@log_callback) + user_data = Parakeet.instance_variable_get(:@log_callback_user_data) + begin + logs = "" + Parakeet.log_set proc {|level, message| logs << message}, nil + params = Parakeet::Context::Params.new(use_gpu: false) + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params) + assert_instance_of Parakeet::Context, parakeet + assert_match /use gpu\s+=\s+0/, logs + ensure + Parakeet.log_set log_callback, user_data + end + end + + sub_test_case "full" do + def setup + super + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} + end + + def test_full + @parakeet.full @params, @samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text + end + + def test_full_without_length + @parakeet.full(@params, @samples) + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator + samples = @samples.each + @parakeet.full @params, samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @parakeet.full @params, samples + end + end + + def test_full_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @parakeet.full @params, samples, 11 + end + end + + def test_full_with_memory_view + samples = JFKReader.new(AUDIO) + @parakeet.full @params, samples + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @parakeet.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + end + + def test_transcribe + assert_nothing_raised do + @parakeet.transcribe AUDIO, @params + end + end + + def test_transcribe_with_pathname + assert_nothing_raised do + @parakeet.transcribe Pathname(AUDIO), @params + end + end + + def test_transcribe_with_nothing + assert_raise_message(/open/) do + @parakeet.transcribe "nothing", @params + end + end +end diff --git a/bindings/ruby/test/test_parakeet_context_params.rb b/bindings/ruby/test/test_parakeet_context_params.rb new file mode 100644 index 000000000..fcd0f2410 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context_params.rb @@ -0,0 +1,24 @@ +require_relative "helper" + +class TestParakeetContextParams < TestBase + def setup + @params = Parakeet::Context::Params.new + end + + def test_new + assert_instance_of Parakeet::Context::Params, @params + end + + def test_attributes + assert_true @params.use_gpu + assert_instance_of Integer, @params.gpu_device + end + + def test_attribute_writer + @params.use_gpu = false + assert_false @params.use_gpu + + @params.gpu_device = 2 + assert_equal 2, @params.gpu_device + end +end diff --git a/bindings/ruby/test/test_parakeet_model.rb b/bindings/ruby/test/test_parakeet_model.rb new file mode 100644 index 000000000..5343b35ed --- /dev/null +++ b/bindings/ruby/test/test_parakeet_model.rb @@ -0,0 +1,21 @@ +require_relative "helper" + +class TestParakeetModel < TestBase + def test_model + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + assert_instance_of Parakeet::Model, parakeet.model + end + + def test_attributes + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + model = parakeet.model + + assert_equal 10, model.n_vocab + assert_equal 3200, model.n_audio_ctx + assert_equal 8, model.n_audio_state + assert_equal 2, model.n_audio_head + assert_equal 1, model.n_audio_layer + assert_equal 16, model.n_mels + assert_equal 0, model.ftype + end +end diff --git a/bindings/ruby/test/test_parakeet_params.rb b/bindings/ruby/test/test_parakeet_params.rb new file mode 100644 index 000000000..dc651f7ab --- /dev/null +++ b/bindings/ruby/test/test_parakeet_params.rb @@ -0,0 +1,78 @@ +require_relative "helper" +require "etc" + +class TestParakeetParams < TestBase + PARAM_NAMES = [ + :n_threads, + :offset_ms, + :duration_ms, + :no_context, + :audio_ctx + ] + + def setup + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Params, @params + end + + def test_n_threads + assert_equal [4, Etc.nprocessors].min, @params.n_threads + + @params.n_threads = 1 + assert_equal 1, @params.n_threads + end + + def test_offset_ms + assert_equal 0, @params.offset_ms + + @params.offset_ms = 10_000 + assert_equal 10_000, @params.offset_ms + end + + def test_duration_ms + assert_equal 0, @params.duration_ms + + @params.duration_ms = 60_000 + assert_equal 60_000, @params.duration_ms + end + + def test_no_context + assert_equal true, @params.no_context + + @params.no_context = false + assert_equal false, @params.no_context + end + + def test_audio_ctx + assert_equal 0, @params.audio_ctx + + @params.audio_ctx = 1 + assert_equal 1, @params.audio_ctx + end + + def test_new_with_kw_args + params = Parakeet::Params.new(n_threads: 1) + assert_equal 1, params.n_threads + assert_equal 0, params.offset_ms + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = case [param, default_value] + in [*, true | false] + !default_value + in [*, Integer] + default_value + 1 + end + params = Parakeet::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + assert_equal @params.send(name), params.send(name) + end + end +end diff --git a/bindings/ruby/test/test_parakeet_segment.rb b/bindings/ruby/test/test_parakeet_segment.rb new file mode 100644 index 000000000..d5b99bd5e --- /dev/null +++ b/bindings/ruby/test/test_parakeet_segment.rb @@ -0,0 +1,42 @@ +require_relative "helper" + +class TestParakeetSegment < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @parakeet.transcribe AUDIO, Parakeet::Params.new + end + + def test_segment + whole_text = "" + @parakeet.each_segment do |segment| + assert_instance_of Parakeet::Segment, segment + assert_kind_of Integer, segment.start_time + assert segment.end_time >= segment.start_time + assert_kind_of String, segment.text + whole_text << segment.text + end + assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text) + end + + def test_deconstruct_keys + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text]) + end + + def test_deconstruct_keys_with_nil + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys(nil) + end +end diff --git a/bindings/ruby/test/test_parakeet_token.rb b/bindings/ruby/test/test_parakeet_token.rb new file mode 100644 index 000000000..6f0b8b5a3 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_token.rb @@ -0,0 +1,73 @@ +require_relative "helper" + +class TestParakeetToken < TestBase + ATTRS = %i[ + id + duration_idx + duration_value + frame_index + probability + log_probability + start_time + end_time + word_start? + text + ] + + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + params = Parakeet::Params.new + parakeet.transcribe AUDIO, params + @segment = parakeet.each_segment.first + end + + def test_each_token + i = 0 + @segment.each_token do |token| + i += 1 + assert_instance_of Parakeet::Token, token + end + assert_equal 38, i + end + + def test_each_token_without_block + assert_instance_of Enumerator, @segment.each_token + end + + def test_token + token = @segment.each_token.first + + assert_instance_of Parakeet::Token, token + assert_instance_of Integer, token.id + assert_instance_of Integer, token.duration_idx + assert_instance_of Integer, token.duration_value + assert_instance_of Integer, token.frame_index + assert_instance_of Float, token.probability + assert_instance_of Float, token.log_probability + assert_instance_of Integer, token.start_time + assert_instance_of Integer, token.end_time + assert_instance_of String, token.text + end + + def test_text + assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."], + @segment.each_token.collect(&:text) + end + + def test_deconstruct_keys_with_nil + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(nil) + end + + def test_deconstruct_keys_with_keys + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(expected.keys) + end +end diff --git a/bindings/ruby/test/test_vad_segment.rb b/bindings/ruby/test/test_vad_segment.rb index 7348562cb..6d66c27fd 100644 --- a/bindings/ruby/test/test_vad_segment.rb +++ b/bindings/ruby/test/test_vad_segment.rb @@ -9,7 +9,7 @@ class TestVADSegment < TestBase end assert_raise do - segments.end_time + segment.end_time end assert_raise do diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index f7e25239d..082547e7c 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -149,6 +149,7 @@ class TestWhisper < TestBase } Whisper.log_set log_callback, user_data Whisper::Context.new("base.en") + sleep 0.1 # wait for logs dequeued assert logs.length > 30 logs.each do |log| diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2d952222f..301ecfcc1 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -23,7 +23,7 @@ Gem::Specification.new do |s| s.test_files = s.files.select {|file| file.start_with? "test/"} s.extensions << 'ext/extconf.rb' - s.required_ruby_version = '>= 3.1.0' + s.required_ruby_version = '>= 3.3.0' #### Documentation and testing. s.homepage = 'https://github.com/ggml-org/whisper.cpp'