diff --git a/bindings/ruby/.document b/bindings/ruby/.document new file mode 100644 index 00000000..a8e9788f --- /dev/null +++ b/bindings/ruby/.document @@ -0,0 +1,3 @@ +README.md +LICENSE +sig diff --git a/bindings/ruby/.rdoc_options b/bindings/ruby/.rdoc_options new file mode 100644 index 00000000..cf14aa5f --- /dev/null +++ b/bindings/ruby/.rdoc_options @@ -0,0 +1,2 @@ +title: whispercpp +main_page: README.md diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 41e7b330..07b81830 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -360,7 +360,7 @@ Whisper::Context.new("base") ### Low-level API to transcribe ### -You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. +You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. Unlike `#transcribe`, these methods require 16,000 Hz, 32-bit float audio. 
```ruby require "whisper" @@ -383,16 +383,16 @@ If you can prepare audio data as C array and export it as a MemoryView, whisperc ```ruby require "torchaudio" -require "arrow-numo-narray" +require "ndav/torch/tensor" require "whisper" waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") -# Convert Torch::Tensor to Arrow::Array via Numo::NArray -samples = waveform.squeeze.numo.to_arrow.to_arrow_array +# Convert Torch::Tensor to NDAV +samples = waveform.squeeze.to_ndav whisper = Whisper::Context.new("base") whisper - # Arrow::Array exports MemoryView + # NDAV exports MemoryView .full(Whisper::Params.new, samples) ``` diff --git a/bindings/ruby/ext/dependencies.rb b/bindings/ruby/ext/dependencies.rb index 2ba4b94b..e77ac0c4 100644 --- a/bindings/ruby/ext/dependencies.rb +++ b/bindings/ruby/ext/dependencies.rb @@ -36,8 +36,7 @@ class Dependencies end def generate_dot - args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"] - args << @options.to_s unless @options.to_s.empty? 
+ args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF", "-C", @options.cache_path] system @cmake, *args, exception: true end diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index acff501a..ce9ffc0e 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -3,7 +3,7 @@ require_relative "options" require_relative "dependencies" cmake = find_executable("cmake") || abort -options = Options.new(cmake).to_s +options = Options.new(cmake) have_library("gomp") rescue nil libs = Dependencies.new(cmake, options).to_s @@ -17,7 +17,7 @@ create_makefile "whisper" do |conf| $(TARGET_SO): #{libs} #{libs}: cmake-targets cmake-targets: - #{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} + #{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON -C #{options.cache_path} #{"\t"}#{cmake} --build build --config Release --target common whisper EOF end diff --git a/bindings/ruby/ext/options.rb b/bindings/ruby/ext/options.rb index ede80c06..5fe600a6 100644 --- a/bindings/ruby/ext/options.rb +++ b/bindings/ruby/ext/options.rb @@ -1,16 +1,16 @@ +require "fileutils" + class Options def initialize(cmake="cmake") @cmake = cmake @options = {} configure + write_cache_file end - def to_s - @options - .reject {|name, (type, value)| value.nil?} - .collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? 
"OFF" : value.shellescape}"} - .join(" ") + def cache_path + File.join(__dir__, "source", "Options.cmake") end def cmake_options @@ -18,7 +18,7 @@ class Options output = nil Dir.chdir __dir__ do - output = `#{@cmake.shellescape} -S sources -B build -L` + output = IO.popen([@cmake, "-S", "sources", "-B", "build", "-L"]) {|io| io.read} end @cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) .filter_map {|line| @@ -82,4 +82,22 @@ class Options op[1] end end + + def write_cache_file + FileUtils.mkpath File.dirname(cache_path) + File.open cache_path, "w" do |file| + @options.reject {|name, (type, value)| value.nil?}.each do |name, (type, value)| + line = "set(CACHE{%s} TYPE %s FORCE VALUE %s)" % { + name:, + type:, + value: value == true ? "ON" : value == false ? "OFF" : escape_cmake(value) + } + file.puts line + end + end + end + + def escape_cmake(str) + str.gsub(/([\\"])/, '\\\\\1') + end end diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 5f1917ee..56fceb1c 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -29,6 +29,7 @@ ID id_cache; ID id_n_processors; static bool is_log_callback_finalized = false; +static bool is_ruby_log_callback_present = false; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -106,18 +107,43 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { return Qnil; } +typedef struct { + int level; + const char * buffer; +} call_log_callbacks_args; + +static void* +call_log_callbacks(void *v_args) { + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (NIL_P(log_callback)) { + return NULL; + } + + call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; + VALUE user_data = rb_iv_get(mWhisper, "user_data"); + rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); + + return NULL; +} + static void ruby_whisper_log_callback(enum 
ggml_log_level level, const char * buffer, void * user_data) { if (is_log_callback_finalized) { return; } - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { + if (!is_ruby_log_callback_present) { return; } - VALUE udata = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); + call_log_callbacks_args args = { + level, + buffer, + }; + if (ruby_thread_has_gvl_p()) { + call_log_callbacks((void *)&args); + } else { + rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); + } } /* @@ -140,8 +166,10 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d if (NIL_P(log_callback)) { whisper_log_set(NULL, NULL); + is_ruby_log_callback_present = false; } else { whisper_log_set(ruby_whisper_log_callback, NULL); + is_ruby_log_callback_present = true; } return Qnil; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 6b0b4df7..6a2d4585 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -3,6 +3,7 @@ #include #include +#include #include #include "whisper.h" @@ -13,6 +14,14 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; +typedef struct { + VALUE *context; + VALUE user_data; + VALUE callback; + VALUE callbacks; + bool is_interrupted; +} ruby_whisper_abort_callback_container; + typedef struct { struct whisper_context *context; } ruby_whisper; @@ -27,7 +36,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_callback_container *abort_callback_container; + ruby_whisper_abort_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 6e38ead6..2428aeff 
100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -47,6 +47,27 @@ typedef struct full_parallel_args { int n_processors; } full_parallel_args; +typedef struct full_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int result; +} full_without_gvl_args; + +typedef struct full_parallel_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int n_processors; + int result; +} full_parallel_without_gvl_args; + +typedef struct full_ubf_args { + ruby_whisper_abort_callback_container *abort_callback_container; +} full_ubf_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -74,7 +95,7 @@ static size_t ruby_whisper_memsize(const void *p) { const ruby_whisper *rw = (const ruby_whisper *)p; - size_t size = sizeof(rw); + size_t size = sizeof(*rw); if (!rw) { return 0; } @@ -304,11 +325,13 @@ VALUE ruby_whisper_model_type(VALUE self) static bool check_memory_view(rb_memory_view_t *memview) { - if (memview->format != NULL && strcmp(memview->format, "f") != 0) { - rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); + if (memview->format != NULL && strcmp(memview->format, "f") != 0 && strcmp(memview->format, "e") != 0) { + // TODO: Accept other formats and convert samples + rb_warn("currently only format \"f\" and \"e\" is supported for MemoryView, but given: %s", memview->format); return false; } - if (memview->format != NULL && memview->ndim != 1) { + if (memview->format != NULL && memview->ndim != 1 && !(memview->ndim == 2 && memview->shape[1] == 1)) { + // TODO: Accept ndim == 2 with shape [n_samples, channels] and channels > 1 by averaging the samples in different channels or just taking the first channel rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); return false; } @@ 
-426,6 +449,22 @@ release_samples(VALUE rb_parsed_args) return Qnil; } +static void* +full_without_gvl(void *rb_args) +{ + full_without_gvl_args *args = (full_without_gvl_args *)rb_args; + args->result = whisper_full(args->context, *args->params, args->samples, args->n_samples); + return NULL; +} + +static void +full_ubf(void *rb_args) +{ + full_ubf_args *args = (full_ubf_args *)rb_args; + + args->abort_callback_container->is_interrupted = true; +} + static VALUE full_body(VALUE rb_args) { @@ -437,9 +476,19 @@ full_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, 1); - int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); - return INT2NUM(result); + struct full_without_gvl_args full_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_without_gvl_args.result); } /* @@ -477,6 +526,14 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static void* +full_parallel_without_gvl(void *rb_args) +{ + full_parallel_without_gvl_args *args = (full_parallel_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + return NULL; +} + static VALUE full_parallel_body(VALUE rb_args) { @@ -488,9 +545,20 @@ full_parallel_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, args->n_processors); - int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); - return INT2NUM(result); + struct full_parallel_without_gvl_args 
full_parallel_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + args->n_processors, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_parallel_without_gvl_args.result); } /* diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 3e5dca9c..2aae7c12 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -93,21 +93,66 @@ rb_whisper_callback_container_allocate() { container->context = NULL; container->user_data = Qnil; container->callback = Qnil; - container->callbacks = rb_ary_new(); + container->callbacks = Qnil; return container; } -static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; +static void +rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) +{ + if (rwc == NULL) return; + + rb_gc_mark(rwc->user_data); + rb_gc_mark(rwc->callback); + rb_gc_mark(rwc->callbacks); +} + +static ruby_whisper_abort_callback_container* +rb_whisper_abort_callback_container_allocate() { + ruby_whisper_abort_callback_container *container; + container = ALLOC(ruby_whisper_abort_callback_container); + container->context = NULL; + container->user_data = Qnil; + container->callback = Qnil; + container->callbacks = Qnil; + container->is_interrupted = false; + return container; +} + +static bool +ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { + return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +static bool +ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { 
+ return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int n_new; +} call_new_segment_callbacks_args; + +static void* +call_new_segment_callbacks(void *v_args) { + call_new_segment_callbacks_args *args = (call_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + struct whisper_state *state = args->state; + int n_new = args->n_new; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } const int n_segments = whisper_full_n_segments_from_state(state); for (int i = n_new; i > 0; i--) { @@ -118,95 +163,208 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta rb_funcall(cb, id_call, 1, segment); } } + + return NULL; +} + +static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_new_segment_callbacks_args args = { + container, + state, + n_new + }; + rb_thread_call_with_gvl(call_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int progress_cur; +} call_progress_callbacks_args; + +static void* +call_progress_callbacks(void *v_args) { + call_progress_callbacks_args *args = (call_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = 
args->container; + int progress_cur = args->progress_cur; + + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(args->container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(progress_cur), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return NULL; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, INT2NUM(progress_cur)); + } + + return NULL; } static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - const VALUE progress = INT2NUM(progress_cur); - // Currently, doesn't support state because - // those require to resolve GC-related problems. 
- if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { + if (!ruby_whisper_callback_container_is_present(container)) { return; } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, progress); - } + + call_progress_callbacks_args args = { + container, + state, + progress_cur + }; + rb_thread_call_with_gvl(call_progress_callbacks, (void *)&args); } -static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - bool is_aborted = false; - VALUE result; +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + bool is_continued; +} call_encoder_begin_callbacks_args; + +static void* +call_encoder_begin_callbacks(void *v_args) { + call_encoder_begin_callbacks_args *args = (call_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; // Currently, doesn't support state because // those require to resolve GC-related problems. 
if (!NIL_P(container->callback)) { result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); if (result == Qfalse) { - is_aborted = true; + args->is_continued = false; + return NULL; } } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return !is_aborted; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - result = rb_funcall(cb, id_call, 0); - if (result == Qfalse) { - is_aborted = true; + if (!NIL_P(container->callbacks)) { + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return NULL; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } } } - return !is_aborted; + + return NULL; } -static bool abort_callback(void * user_data) { +static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_encoder_begin_callbacks_args args = { + container, + state, + true + }; + rb_thread_call_with_gvl(call_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_abort_callback_container *container; + struct whisper_state *state; + bool is_interrupted; +} call_abort_callbacks_args; + +static void* +call_abort_callbacks(void *v_args) { + call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; + const ruby_whisper_abort_callback_container *container = args->container; + + if (container->is_interrupted) { + args->is_interrupted = true; + return NULL; + } + if (!NIL_P(container->callback)) { VALUE result = rb_funcall(container->callback, 
id_call, 1, container->user_data); if (!NIL_P(result) && Qfalse != result) { - return true; + args->is_interrupted = true; + return NULL; } } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return false; + return NULL; } for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); VALUE result = rb_funcall(cb, id_call, 1, container->user_data); if (!NIL_P(result) && Qfalse != result) { - return true; + args->is_interrupted = true; + return NULL; } } - return false; + + return NULL; +} + +static bool abort_callback(void * user_data) { + const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + + if (container->is_interrupted) { + return true; + } + + if (!ruby_whisper_abort_callback_container_is_present(container)) { + return false; + } + + call_abort_callbacks_args args = { + container, + NULL, + false + }; + rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); + + return args.is_interrupted; } static void -check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) +check_thread_safety(ruby_whisper_params *rwp, int n_processors) { if (n_processors == 1) { return; } - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != 
RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } @@ -217,29 +375,28 @@ check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) } static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rwp->progress_callback_container->context = context; rwp->params.progress_callback = progress_callback; rwp->params.progress_callback_user_data = rwp->progress_callback_container; } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rwp->encoder_begin_callback_container->context = context; rwp->params.encoder_begin_callback = encoder_begin_callback; rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } - if 
(!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { - rwp->abort_callback_container->context = context; - rwp->params.abort_callback = abort_callback; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; - } + rwp->abort_callback_container->context = context; + rwp->params.abort_callback = abort_callback; + rwp->abort_callback_container->is_interrupted = false; + rwp->params.abort_callback_user_data = rwp->abort_callback_container; } static void set_vad_params(ruby_whisper_params *rwp) @@ -255,7 +412,7 @@ static void set_vad_params(ruby_whisper_params *rwp) void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) { - check_thread_safety(rwp, context, n_processors); + check_thread_safety(rwp, n_processors); register_callbacks(rwp, context); set_vad_params(rwp); } @@ -267,7 +424,7 @@ rb_whisper_params_mark(void *p) rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_callbcack_container_mark(rwp->abort_callback_container); + rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -338,7 +495,7 @@ ruby_whisper_params_allocate(VALUE klass) rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_callback_container_allocate(); + rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); return obj; } @@ -1302,6 +1459,9 @@ ruby_whisper_params_on_new_segment(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = 
rb_block_proc(); + if (NIL_P(rwp->new_segment_callback_container->callbacks)) { + rwp->new_segment_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->new_segment_callback_container->callbacks, blk); return Qnil; } @@ -1322,6 +1482,9 @@ ruby_whisper_params_on_progress(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->progress_callback_container->callbacks)) { + rwp->progress_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->progress_callback_container->callbacks, blk); return Qnil; } @@ -1342,6 +1505,9 @@ ruby_whisper_params_on_encoder_begin(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->encoder_begin_callback_container->callbacks)) { + rwp->encoder_begin_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk); return Qnil; } @@ -1366,6 +1532,9 @@ ruby_whisper_params_abort_on(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->abort_callback_container->callbacks)) { + rwp->abort_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->abort_callback_container->callbacks, blk); return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 3d005660..37656af1 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -15,8 +15,37 @@ extern ID id_call; extern ID id_to_path; extern ID transcribe_option_names[1]; -extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int 
n_processors); + +typedef struct{ + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + size_t n_samples; + int n_processors; + int result; +} transcribe_without_gvl_args; + +static void* +transcribe_without_gvl(void *rb_args) +{ + transcribe_without_gvl_args *args = (transcribe_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_container *abort_callback_container; +} transcribe_ubf_args; + +static void +transcribe_ubf(void *rb_args) +{ + transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; + + args->abort_callback_container->is_interrupted = true; +} /* * transcribe a single file @@ -75,7 +104,19 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { prepare_transcription(rwp, &self, n_processors); - if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { + transcribe_without_gvl_args args = { + rw->context, + &rwp->params, + pcmf32.data(), + pcmf32.size(), + n_processors, + 0, + }; + transcribe_ubf_args ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); + if (args.result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 3c596619..cbec4803 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -5,10 +5,10 @@ module Whisper end type log_callback = ^(Integer level, String message, Object user_data) -> void - type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void - type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void - type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) 
-> void - type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish + type new_segment_callback = ^(Whisper::Context, untyped, Integer n_new, Object user_data) -> void + type progress_callback = ^(Whisper::Context, untyped, Integer progress, Object user_data) -> void + type encoder_begin_callback = ^(Whisper::Context, untyped, Object user_data) -> void + type abort_callback = ^(Whisper::Context, untyped, Object user_data) -> boolish VERSION: String LOG_LEVEL_NONE: Integer @@ -52,11 +52,11 @@ module Whisper # puts text # end # - # If n_processors is greater than 1, you cannot set any callbacks including + # If `n_processors` is greater than 1, you cannot set any callbacks including # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, # and log_callback set by Whisper.log_set - def transcribe: (path, Params, ?n_processors: Integer) -> self - | (path, Params, ?n_processors: Integer) { (String) -> void } -> self + def transcribe: (path, Whisper::Params, ?n_processors: Integer) -> self + | (path, Whisper::Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -74,7 +74,7 @@ module Whisper # puts segment.text # end # - # Returns an Enumerator if no block given: + # Returns an `Enumerator` if no block given: # # whisper.transcribe("path/to/audio.wav", params) # enum = whisper.each_segment @@ -91,25 +91,25 @@ module Whisper # def full_lang_id: () -> Integer - # Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t0(3) # => 1668 (16680 ms) # def full_get_segment_t0: (Integer) -> Integer - # End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). 
# # full_get_segment_t1(3) # => 1668 (16680 ms) # def full_get_segment_t1: (Integer) -> Integer - # Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. + # Whether the next segment indexed by `segment_index` is predicted as a speaker turn. # # full_get_segment_speacker_turn_next(3) # => true # def full_get_segment_speaker_turn_next: (Integer) -> (true | false) - # Text of a segment indexed by +segment_index+. + # Text of a segment indexed by `segment_index`. # # full_get_segment_text(3) # => "ask not what your country can do for you, ..." # @@ -117,27 +117,27 @@ module Whisper def full_get_segment_no_speech_prob: (Integer) -> Float - # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - # Not thread safe for same context + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + # Not thread safe for same context # Uses the specified decoding strategy to obtain the text. # - # The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # The second argument `samples` must be an array of samples, respond to `:length`, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. # - def full: (Params, Array[Float] samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self + def full: (Whisper::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self - # Split the input audio in chunks and process each chunk separately using whisper_full_with_state() - # Result is stored in the default state of the context - # Not thread safe if executed in parallel on the same context. - # It seems this approach can offer some speedup in some cases. 
+ # Split the input audio in chunks and process each chunk separately using `whisper_full_with_state()` + # Result is stored in the default state of the context + # Not thread safe if executed in parallel on the same context. + # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # - # If n_processors is greater than 1, you cannot set any callbacks including + # If `n_processors` is greater than 1, you cannot set any callbacks including # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, # and log_callback set by Whisper.log_set - def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self + def full_parallel: (Whisper::Params, Array[Float], ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self def to_srt: () -> String def to_webvtt: () -> String @@ -217,35 +217,35 @@ module Whisper def translate: () -> (true | false) def no_context=: (boolish) -> boolish - # If true, does not use past transcription (if any) as initial prompt for the decoder. + # If `true`, does not use past transcription (if any) as initial prompt for the decoder. # def no_context: () -> (true | false) def single_segment=: (boolish) -> boolish - # If true, forces single segment output (useful for streaming). + # If `true`, forces single segment output (useful for streaming). # def single_segment: () -> (true | false) def print_special=: (boolish) -> boolish - # If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.). + # If `true`, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.). # def print_special: () -> (true | false) def print_progress=: (boolish) -> boolish - # If true, prints progress information. 
+ # If `true`, prints progress information. # def print_progress: () -> (true | false) def print_realtime=: (boolish) -> boolish - # If true, prints results from within whisper.cpp. (avoid it, use callback instead) + # If `true`, prints results from within whisper.cpp. (avoid it, use callback instead) # def print_realtime: () -> (true | false) - # If true, prints timestamps for each text segment when printing realtime. + # If `true`, prints timestamps for each text segment when printing realtime. # def print_timestamps=: (boolish) -> boolish @@ -253,19 +253,19 @@ module Whisper def suppress_blank=: (boolish) -> boolish - # If true, suppresses blank outputs. + # If `true`, suppresses blank outputs. # def suppress_blank: () -> (true | false) def suppress_nst=: (boolish) -> boolish - # If true, suppresses non-speech-tokens. + # If `true`, suppresses non-speech-tokens. # def suppress_nst: () -> (true | false) def token_timestamps=: (boolish) -> boolish - # If true, enables token-level timestamps. + # If `true`, enables token-level timestamps. # def token_timestamps: () -> (true | false) @@ -277,16 +277,16 @@ module Whisper def split_on_word=: (boolish) -> boolish - # If true, split on word rather than on token (when used with max_len). + # If `true`, split on word rather than on token (when used with max_len). # def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS def carry_initial_prompt=: (boolish) -> boolish - # Tokens to provide to the whisper decoder as initial prompt - # these are prepended to any existing text context from a previous call - # use whisper_tokenize() to convert text to tokens. + # Tokens to provide to the whisper decoder as initial prompt + # these are prepended to any existing text context from a previous call + # use whisper_tokenize() to convert text to tokens. # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). 
# def initial_prompt: () -> (String | nil) @@ -294,7 +294,7 @@ module Whisper def diarize=: (boolish) -> boolish - # If true, enables diarization. + # If `true`, enables diarization. # def diarize: () -> (true | false) @@ -423,7 +423,7 @@ module Whisper # def on_new_segment: { (Segment) -> void } -> void - # Hook called on progress update. Yields each progress Integer between 0 and 100. + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. # def on_progress: { (Integer progress) -> void } -> void @@ -431,7 +431,7 @@ module Whisper # def on_encoder_begin: { () -> void } -> void - # Call block to determine whether abort or not. Return +true+ when you want to abort. + # Call block to determine whether abort or not. Return `true` when you want to abort. # # params.abort_on do # if some_condition @@ -504,13 +504,13 @@ module Whisper # Yields each Whisper::Token: # - # whisper.each_segment.first.each_token do |token| - # p token - # end + # whisper.each_segment.first.each_token do |token| + # p token + # end # - # Returns an Enumerator if no block is given: + # Returns an `Enumerator` if no block is given: # - # whisper.each_segment.first.each_token.to_a # => [#, ...] + # whisper.each_segment.first.each_token.to_a # => [#, ...] # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] @@ -518,7 +518,7 @@ module Whisper def to_webvtt_cue: () -> String - # Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next + # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # # whisper.each_segment do |segment| # segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:} @@ -569,7 +569,7 @@ module Whisper # [EXPERIMENTAL] Token-level timestamps with DTW # - # Do not use if you haven't computed token-level timestamps with dtw. + # Do not use if you haven't computed token-level timestamps with dtw. 
# Roughly corresponds to the moment in audio in which the token was output. # def t_dtw: () -> Integer @@ -580,14 +580,14 @@ module Whisper # Start time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def start_time: () -> Integer # End time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def end_time: () -> Integer