diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 78bc3fc3e..0e907a614 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -134,6 +134,13 @@ typedef struct { } \ } while (0) +#define GetParakeetContext(obj, rwpc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ + if ((rwpc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #define GetParakeetParams(obj, rwpp) do { \ TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \ if (!(rwpp)->new_segment_callback_container || \ diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c index cefb1b233..226e65e10 100644 --- a/bindings/ruby/ext/ruby_whisper_parakeet_context.c +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -3,6 +3,7 @@ extern ID id_to_s; extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params); static void ruby_whisper_parakeet_context_free(void *p) @@ -25,7 +26,7 @@ ruby_whisper_parakeet_context_memsize(const void *p) return size; } -static const rb_data_type_t ruby_whisper_parakeet_context_type = { +const rb_data_type_t ruby_whisper_parakeet_context_type = { "ruby_whisper_parakeet_context", {0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,}, 0, 0, @@ -72,4 +73,5 @@ init_ruby_whisper_parakeet_context(VALUE *mParakeet) rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate); rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1); + rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); } diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp new file mode 100644 index 000000000..ba713dc8c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -0,0 +1,47 @@ +#include "ruby_whisper.h" +#include "common-whisper.h" +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern const rb_data_type_t ruby_whisper_parakeet_params_type; + +extern ID id_to_s; +extern ID id_to_path; + +VALUE +ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) +{ + if (rb_respond_to(audio_path, id_to_path)) { + audio_path = rb_funcall(audio_path, id_to_path, 0); + } + + std::string fname = StringValueCStr(audio_path); + std::vector pcmf32; + std::vector> pcmf32s; + + if (!read_audio_data(fname, pcmf32, pcmf32s, false)) { + rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str()); + return Qnil; + } + + ruby_whisper_parakeet_context *rwpc; + ruby_whisper_parakeet_params *rwpp; + GetParakeetContext(self, rwpc); + GetParakeetParams(params, rwpp); + + if (parakeet_full(rwpc->context, rwpp->params, pcmf32.data(), pcmf32.size()) != 0) { + rb_raise(rb_eRuntimeError, "Failed to process audio"); + return Qnil; + } + + return self; +} + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/test/test_parakeet_context.rb b/bindings/ruby/test/test_parakeet_context.rb index 295112bdd..9150fe926 100644 --- a/bindings/ruby/test/test_parakeet_context.rb +++ b/bindings/ruby/test/test_parakeet_context.rb @@ -1,7 +1,30 @@ require_relative "helper" class TestParakeetContext < TestBase + def setup + @parakeet = Parakeet::Context.new(File.join(__dir__, "../../../models/parakeet-tdt-0.6b-v3.bin")) + @params = Parakeet::Params.new + end + def test_new - assert_instance_of Parakeet::Context, Parakeet::Context.new(File.join(__dir__, "../../../models/parakeet-tdt-0.6b-v3.bin")) + assert_instance_of Parakeet::Context, @parakeet + end + + def test_transcribe + assert_nothing_raised do + @parakeet.transcribe AUDIO, @params + end + end + + def test_transcribe_with_pathname + assert_nothing_raised do + @parakeet.transcribe Pathname(AUDIO), @params + end + end + + def test_transcribe_with_nothing + assert_raise_message(/open/) do + @parakeet.transcribe "nothing", @params + end end end