ruby : fix dangling pointers, memory leak, and SEGV on parallel transcription (#3715)
* Prevent dangling pointers * Use proper free function * Free callback containers * Set default log callback when nil is passed to log_set * Raise error if callbacks set when parallel transcription * Bump version to 1.3.7 * Make tests follow spec change * Add note on parallel transcription and callbacks * Update signature of Whisper.log_set [skip ci]
This commit is contained in:
parent
9386f23940
commit
76684141a5
|
|
@ -202,6 +202,8 @@ whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors)
|
|||
|
||||
Note that transcription occasionally might be low accuracy when it works in parallel.
|
||||
|
||||
If n_processors is greater than 1, you cannot set any callbacks including new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, and log_callback set by Whisper.log_set.
|
||||
|
||||
### Segments ###
|
||||
|
||||
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void *
|
|||
return;
|
||||
}
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
if (NIL_P(log_callback)) {
|
||||
return;
|
||||
}
|
||||
|
||||
VALUE udata = rb_iv_get(mWhisper, "user_data");
|
||||
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
|
||||
}
|
||||
|
|
@ -129,10 +133,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
|
|||
rb_iv_set(self, "log_callback", log_callback);
|
||||
rb_iv_set(self, "user_data", user_data);
|
||||
|
||||
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
|
||||
rb_define_finalizer(log_callback, finalize_log_callback);
|
||||
if (!NIL_P(log_callback)) {
|
||||
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
|
||||
rb_define_finalizer(log_callback, finalize_log_callback);
|
||||
}
|
||||
|
||||
whisper_log_set(ruby_whisper_log_callback, NULL);
|
||||
if (NIL_P(log_callback)) {
|
||||
whisper_log_set(NULL, NULL);
|
||||
} else {
|
||||
whisper_log_set(ruby_whisper_log_callback, NULL);
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#define RUBY_WHISPER_H
|
||||
|
||||
#include <ruby.h>
|
||||
#include <ruby/util.h>
|
||||
#include <ruby/memory_view.h>
|
||||
#include "whisper.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
|
|||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_s_new(VALUE context);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
|
||||
|
||||
ID transcribe_option_names[1];
|
||||
|
||||
|
|
@ -436,7 +436,7 @@ full_body(VALUE rb_args)
|
|||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context);
|
||||
prepare_transcription(rwp, args->context, 1);
|
||||
int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples);
|
||||
|
||||
return INT2NUM(result);
|
||||
|
|
@ -487,7 +487,7 @@ full_parallel_body(VALUE rb_args)
|
|||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context);
|
||||
prepare_transcription(rwp, args->context, args->n_processors);
|
||||
int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors);
|
||||
|
||||
return INT2NUM(result);
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@
|
|||
|
||||
extern VALUE cParams;
|
||||
extern VALUE cVADParams;
|
||||
extern VALUE mWhisper;
|
||||
|
||||
extern ID id_call;
|
||||
|
||||
|
|
@ -186,6 +187,35 @@ static bool abort_callback(void * user_data) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static void
|
||||
check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors)
|
||||
{
|
||||
if (n_processors == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
|
||||
rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
|
||||
rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
|
||||
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
|
||||
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
if (!NIL_P(log_callback)) {
|
||||
rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription");
|
||||
}
|
||||
}
|
||||
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
|
||||
rwp->new_segment_callback_container->context = context;
|
||||
|
|
@ -219,9 +249,13 @@ static void set_vad_params(ruby_whisper_params *rwp)
|
|||
rwp->params.vad_params = rwvp->params;
|
||||
}
|
||||
|
||||
/*
|
||||
TODO: Set abort callback to trap SIGINT and SIGTERM
|
||||
*/
|
||||
void
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context)
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
|
||||
{
|
||||
check_thread_safety(rwp, context, n_processors);
|
||||
register_callbacks(rwp, context);
|
||||
set_vad_params(rwp);
|
||||
}
|
||||
|
|
@ -240,6 +274,20 @@ rb_whisper_params_mark(void *p)
|
|||
void
|
||||
ruby_whisper_params_free(ruby_whisper_params *rwp)
|
||||
{
|
||||
if (rwp->params.language) {
|
||||
ruby_xfree((void *)rwp->params.language);
|
||||
}
|
||||
if (rwp->params.initial_prompt) {
|
||||
ruby_xfree((void *)rwp->params.initial_prompt);
|
||||
}
|
||||
if (rwp->params.vad_model_path) {
|
||||
ruby_xfree((void *)rwp->params.vad_model_path);
|
||||
}
|
||||
|
||||
xfree(rwp->new_segment_callback_container);
|
||||
xfree(rwp->progress_callback_container);
|
||||
xfree(rwp->encoder_begin_callback_container);
|
||||
xfree(rwp->abort_callback_container);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -248,7 +296,7 @@ rb_whisper_params_free(void *p)
|
|||
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
|
||||
// How to free user_data and callback only when not referred to by others?
|
||||
ruby_whisper_params_free(rwp);
|
||||
free(rwp);
|
||||
xfree(rwp);
|
||||
}
|
||||
|
||||
static size_t
|
||||
|
|
@ -276,6 +324,15 @@ ruby_whisper_params_allocate(VALUE klass)
|
|||
ruby_whisper_params *rwp;
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
if (rwp->params.language != NULL) {
|
||||
rwp->params.language = ruby_strdup(rwp->params.language);
|
||||
}
|
||||
if (rwp->params.initial_prompt != NULL) {
|
||||
rwp->params.initial_prompt = ruby_strdup(rwp->params.initial_prompt);
|
||||
}
|
||||
if (rwp->params.vad_model_path != NULL) {
|
||||
rwp->params.vad_model_path = ruby_strdup(rwp->params.vad_model_path);
|
||||
}
|
||||
rwp->diarize = false;
|
||||
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
|
||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||
|
|
@ -296,10 +353,12 @@ ruby_whisper_params_set_language(VALUE self, VALUE value)
|
|||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
ruby_xfree((void *)rwp->params.language);
|
||||
rwp->params.language = NULL;
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->params.language = "auto";
|
||||
rwp->params.language = ruby_strdup("auto");
|
||||
} else {
|
||||
rwp->params.language = StringValueCStr(value);
|
||||
rwp->params.language = ruby_strdup(StringValueCStr(value));
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
|
@ -608,7 +667,13 @@ ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
|
|||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
rwp->params.initial_prompt = StringValueCStr(value);
|
||||
ruby_xfree((void *)rwp->params.initial_prompt);
|
||||
rwp->params.initial_prompt = NULL;
|
||||
if (NIL_P(value)) {
|
||||
rwp->params.initial_prompt = NULL;
|
||||
} else {
|
||||
rwp->params.initial_prompt = ruby_strdup(StringValueCStr(value));
|
||||
}
|
||||
return value;
|
||||
}
|
||||
/*
|
||||
|
|
@ -1103,12 +1168,14 @@ ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value)
|
|||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
ruby_xfree((void *)rwp->params.vad_model_path);
|
||||
rwp->params.vad_model_path = NULL;
|
||||
if (NIL_P(value)) {
|
||||
rwp->params.vad_model_path = NULL;
|
||||
return value;
|
||||
}
|
||||
VALUE path = ruby_whisper_normalize_model_path(value);
|
||||
rwp->params.vad_model_path = StringValueCStr(path);
|
||||
rwp->params.vad_model_path = ruby_strdup(StringValueCStr(path));
|
||||
return value;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ extern ID id_to_path;
|
|||
extern ID transcribe_option_names[1];
|
||||
|
||||
extern void
|
||||
prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
|
||||
prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
|
|
@ -73,7 +73,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|||
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
// }
|
||||
|
||||
prepare_transcription(rwp, &self);
|
||||
prepare_transcription(rwp, &self, n_processors);
|
||||
|
||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ module Whisper
|
|||
def self.lang_id: (string name) -> Integer
|
||||
def self.lang_str: (Integer id) -> String
|
||||
def self.lang_str_full: (Integer id) -> String
|
||||
def self.log_set: (log_callback, Object? user_data) -> log_callback
|
||||
def self.log_set: (log_callback?, Object? user_data) -> log_callback
|
||||
def self.system_info_str: () -> String
|
||||
|
||||
class Context
|
||||
|
|
@ -52,6 +52,9 @@ module Whisper
|
|||
# puts text
|
||||
# end
|
||||
#
|
||||
# If n_processors is greater than 1, you cannot set any callbacks including
|
||||
# new_segment_callback, progress_callback, encoder_begin_callback, abort_callback,
|
||||
# and log_callback set by Whisper.log_set
|
||||
def transcribe: (path, Params, ?n_processors: Integer) -> self
|
||||
| (path, Params, ?n_processors: Integer) { (String) -> void } -> self
|
||||
|
||||
|
|
@ -129,6 +132,9 @@ module Whisper
|
|||
# It seems this approach can offer some speedup in some cases.
|
||||
# However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
#
|
||||
# If n_processors is greater than 1, you cannot set any callbacks including
|
||||
# new_segment_callback, progress_callback, encoder_begin_callback, abort_callback,
|
||||
# and log_callback set by Whisper.log_set
|
||||
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ class TestParams < TestBase
|
|||
def test_language
|
||||
@params.language = "en"
|
||||
assert_equal @params.language, "en"
|
||||
GC.compact
|
||||
assert_equal @params.language, "en"
|
||||
@params.language = "auto"
|
||||
assert_equal @params.language, "auto"
|
||||
end
|
||||
|
|
|
|||
|
|
@ -43,9 +43,20 @@ class TestWhisper < TestBase
|
|||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
|
||||
@whisper.transcribe(AUDIO, params, n_processors: 4) {|text|
|
||||
assert_match(/what you can do for your country/i, text)
|
||||
}
|
||||
without_log_callback do
|
||||
@whisper.transcribe(AUDIO, params, n_processors: 4) {|text|
|
||||
assert_match(/what you can do for your country/i, text)
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def without_log_callback
|
||||
Whisper.log_set nil, nil
|
||||
yield
|
||||
ensure
|
||||
Whisper.log_set ->(level, buffer, user_data) {}, nil
|
||||
end
|
||||
|
||||
sub_test_case "After transcription" do
|
||||
|
|
@ -229,7 +240,9 @@ class TestWhisper < TestBase
|
|||
|
||||
def test_full_parallel
|
||||
nprocessors = 2
|
||||
@whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
|
||||
without_log_callback do
|
||||
@whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
|
||||
end
|
||||
|
||||
assert_equal nprocessors, @whisper.full_n_segments
|
||||
text = @whisper.each_segment.collect(&:text).join
|
||||
|
|
@ -240,7 +253,9 @@ class TestWhisper < TestBase
|
|||
def test_full_parallel_with_memory_view
|
||||
nprocessors = 2
|
||||
samples = JFKReader.new(AUDIO)
|
||||
@whisper.full_parallel(@params, samples, nil, nprocessors)
|
||||
without_log_callback do
|
||||
@whisper.full_parallel(@params, samples, nil, nprocessors)
|
||||
end
|
||||
|
||||
assert_equal nprocessors, @whisper.full_n_segments
|
||||
text = @whisper.each_segment.collect(&:text).join
|
||||
|
|
@ -259,7 +274,9 @@ class TestWhisper < TestBase
|
|||
|
||||
def test_full_parallel_without_length
|
||||
nprocessors = 2
|
||||
@whisper.full_parallel(@params, @samples, nil, nprocessors)
|
||||
without_log_callback do
|
||||
@whisper.full_parallel(@params, @samples, nil, nprocessors)
|
||||
end
|
||||
|
||||
assert_equal nprocessors, @whisper.full_n_segments
|
||||
text = @whisper.each_segment.collect(&:text).join
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ require_relative "extsources"
|
|||
Gem::Specification.new do |s|
|
||||
s.name = "whispercpp"
|
||||
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
|
||||
s.version = '1.3.6'
|
||||
s.version = '1.3.7'
|
||||
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
|
||||
s.email = 'todd.fisher@gmail.com'
|
||||
s.extra_rdoc_files = ['LICENSE', 'README.md']
|
||||
|
|
|
|||
Loading…
Reference in New Issue