ruby : add support for Parakeet (#3885)
* Add Whisper::Parakeet::Params * Add tests for Parakeet::Params * Remove unused variabel * Add callbacks to Parakeet::Params * Group callback and user_data params * Undefine local macros * Define GetParakeetParams * Remove unused variable * Use ITERATE_CALLBACK_PARAMS * Use ITERATE_CALLBACK_PARAMS instead of ITERATE_USER_DATA_PARAMS * Fix memsize * Remove unnecessary macros * Simplify params registration * Define Parakeet * Add hook methods to Parakeet::Params * Fix typo * Check callback container in GetParakeetParams * Reduce if * Free parakeet_full_params * Implement Parakeet::Context#initialize * Add TestParakeetContext * Add Parakeet::Segment * Prevent double-free * Add Parakeet::Context#transcribe * Add Parakeet::Context#each_segment * Define Parakeet::Segment attributes * Define Parakeet::Segment#deconstruct_keys * Add tests for Parakeet::Segment#deconstruct_keys * Run Parakeet::Context#transcribe without GVL * Make it to abort for Parakeet * Add Parakeet.log_set * Define Parakeet::Token * Define Parakeet::Segment#each_token * Implement some hooks of Parakeet::Params * Convert int to VALUE * Implement hooks for Parakeet * Implement Parakeet::Context#full * Add tests for Parakeet::Context#full * Add Parakeet to RBS * Fix ruby_whisper_parakeet_params_free * Free ruby_whisper_parakeet_context * Add tests for hooks * Add Parakeet section to README * Add more attributes of Parakeet::Context * Add tests for Parakeet::Context's attributes * Update RBS * Register parakeet-tdt-0.6b-v3 * Narrow scope of log constants * Extract activate and deactivate of log_queue * Make start_log_callback_thread private * Don't call start_log_callback_thread unncecessarilly * Early return from log_queue_enqueue when not active * Gropu log_queue members * is_active -> is_open * Fix English * Share parakeet full body function * ruby_whisper_parakeet_abort_callback_user_data -> ruby_whisper_abort_callback_user_data * NULL check for callback containers * Fix Parakeet.log_set * Omit Parakeet tests on CI * Extract Whisper::LogSettable * Join log callback thread in a log queue function * Revert Join log callback thread in a log queue function * Extract output methods to modules * Move Parakeet init functions into init_parakeet() * Add output methods to Parakeet classes * Add Parakeet's output methods to RBS * Use Whisper::Output in RBS * Add LogSettable to RBS * Fix module Token -> class Token * Add Parakeet::Model * Add test for Parakeet::Model * Add Parakeet::Model to RBS * Move position of Parakeet::Model in RBS * Parakeet -> TestBase::Parakeet * Add Parakeet::Context#model in RBS * Add Whisper::Output * Fix nil check * Define ruby_whisper_parakeet_model_memsize * Fix order of declaration in ruby_whisper_parakeet_model_get_xxx * Define Parakeet.system_info_str * Add test for Parakeet.system_info_str * Add signature of Parakeet.system_info_str * Define Parakeet::VERSION * Add test for Parakeet::VERSION * Add signature of Parakeet::VERSION * Add Parakeet::Context::Params * Make Parakeet::Context.new accept Context::Params * Add test for Parakeet::Context.new with Context::Params * Update RBS * Remove params from Parakeet::Params which are moved from whisper_parakeet_full_params * Remove tests for removed params * Make Parakeet tests follow original behavior changes * Add Parakeet model shortcuts * Alloc token data in factory instead of alloc func * Fix variable name * Update RBS * Refactor log settable module * Use log settable for Whisper * Address deadlock * Make test follow change of log queue implementation * Refactor to make abort callback use the same way to parakeet's way * Remove redundant structs * Fix test name * Fix README * Add missing parallel transcription * Fix test for parakeet info * Remove removed params * Wait for logs dequeued * Fix instance variable name * Load etc feature * Remove unnecessary comment * Remove unnecessary thread safety check * Remove outdated comment * Skip downloading model if cache exists * Change Hugging Face URI for Parakeet models * Bump required Ruby version to 3.3 * Fix English
This commit is contained in:
parent
9efddafb91
commit
0d14756929
|
|
@ -27,6 +27,6 @@ jobs:
|
|||
steps:
|
||||
- uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0
|
||||
with:
|
||||
ruby-version: '3.2'
|
||||
ruby-version: '3.3'
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6
|
||||
- run: rake test
|
||||
|
|
|
|||
|
|
@ -396,6 +396,37 @@ whisper
|
|||
.full(Whisper::Params.new, samples)
|
||||
```
|
||||
|
||||
### Parakeet ###
|
||||
|
||||
whispercpp gem now supports NVIDIA's ASR model Parakeet.
|
||||
|
||||
If you want to use Parakeet instead of Whisper, the API should feel familiar.
|
||||
In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way.
|
||||
|
||||
```ruby
|
||||
require "whisper"
|
||||
|
||||
# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem.
|
||||
Parakeet = Whisper::Parakeet
|
||||
|
||||
parakeet = Parakeet::Context.new("path/to/model")
|
||||
|
||||
params = Parakeet::Params.new(
|
||||
no_context: true
|
||||
)
|
||||
|
||||
parakeet
|
||||
.transcribe("path/to/audio.wav", params)
|
||||
.each_segment do |segment|
|
||||
puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}"
|
||||
end
|
||||
```
|
||||
|
||||
The main differences are:
|
||||
|
||||
* Namespace is `Whisper::Parakeet`.
|
||||
* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks.
|
||||
|
||||
Custom context params
|
||||
---------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,21 @@ else
|
|||
end
|
||||
end
|
||||
|
||||
TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin"
|
||||
TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin"))
|
||||
TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d")
|
||||
directory TEST_PARAKEET_MODEL_DIR
|
||||
if File.exist? TEST_PARAKEET_MODEL_SRC
|
||||
file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t|
|
||||
symlink t.source, t.name
|
||||
end
|
||||
else
|
||||
require "open-uri"
|
||||
file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t|
|
||||
File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read
|
||||
end
|
||||
end
|
||||
|
||||
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
|
||||
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
|
||||
chdir "test/jfk_reader" do
|
||||
|
|
@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
|
|||
end
|
||||
CLEAN.include TEST_MEMORY_VIEW
|
||||
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL]
|
||||
|
|
|
|||
|
|
@ -1,19 +1,29 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
VALUE mWhisper;
|
||||
VALUE mLogSettable;
|
||||
VALUE mVAD;
|
||||
VALUE mParakeet;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
VALUE cVADContext;
|
||||
VALUE cVADParams;
|
||||
VALUE cVADSegments;
|
||||
VALUE cVADSegment;
|
||||
VALUE cParakeetContext;
|
||||
VALUE cParakeetContextParams;
|
||||
VALUE cParakeetParams;
|
||||
VALUE cParakeetSegment;
|
||||
VALUE cParakeetModel;
|
||||
VALUE eError;
|
||||
|
||||
VALUE cSegment;
|
||||
VALUE cToken;
|
||||
VALUE cModel;
|
||||
|
||||
VALUE mOutputContext;
|
||||
VALUE mOutputSegment;
|
||||
|
||||
ID id_to_s;
|
||||
ID id_call;
|
||||
ID id___method__;
|
||||
|
|
@ -27,9 +37,11 @@ ID id_pre_converted_models;
|
|||
ID id_coreml_compiled_models;
|
||||
ID id_cache;
|
||||
ID id_n_processors;
|
||||
|
||||
static bool is_log_callback_finalized = false;
|
||||
static bool is_ruby_log_callback_present = false;
|
||||
ID id_extended;
|
||||
ID id_start_log_callback_thread;
|
||||
ID id_log_callback_thread;
|
||||
ID id_alive_p;
|
||||
ID id_join;
|
||||
|
||||
// High level API
|
||||
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
|
||||
|
|
@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD);
|
|||
extern void init_ruby_whisper_vad_context(VALUE *mVAD);
|
||||
extern void init_ruby_whisper_vad_segment(VALUE *mVAD);
|
||||
extern void init_ruby_whisper_vad_segments(VALUE *mVAD);
|
||||
extern void init_ruby_whisper_parakeet(VALUE *mWhisper);
|
||||
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
static ruby_whisper_log_queue whisper_log_queue;
|
||||
|
||||
LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set)
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* lang_max_id -> Integer
|
||||
|
|
@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) {
|
|||
return rb_str_new2(whisper_print_system_info());
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
|
||||
is_log_callback_finalized = true;
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
int level;
|
||||
const char * buffer;
|
||||
} call_log_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_log_callbacks(void *v_args) {
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
if (NIL_P(log_callback)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
call_log_callbacks_args *args = (call_log_callbacks_args *)v_args;
|
||||
VALUE user_data = rb_iv_get(mWhisper, "user_data");
|
||||
rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
|
||||
if (is_log_callback_finalized) {
|
||||
return;
|
||||
}
|
||||
if (!is_ruby_log_callback_present) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_log_callbacks_args args = {
|
||||
level,
|
||||
buffer,
|
||||
};
|
||||
if (ruby_thread_has_gvl_p()) {
|
||||
call_log_callbacks((void *)&args);
|
||||
} else {
|
||||
rb_thread_call_with_gvl(call_log_callbacks, (void *)&args);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
|
||||
*/
|
||||
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
|
||||
VALUE old_callback = rb_iv_get(self, "log_callback");
|
||||
if (!NIL_P(old_callback)) {
|
||||
rb_undefine_finalizer(old_callback);
|
||||
}
|
||||
|
||||
rb_iv_set(self, "log_callback", log_callback);
|
||||
rb_iv_set(self, "user_data", user_data);
|
||||
|
||||
if (!NIL_P(log_callback)) {
|
||||
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
|
||||
rb_define_finalizer(log_callback, finalize_log_callback);
|
||||
}
|
||||
|
||||
if (NIL_P(log_callback)) {
|
||||
whisper_log_set(NULL, NULL);
|
||||
is_ruby_log_callback_present = false;
|
||||
} else {
|
||||
whisper_log_set(ruby_whisper_log_callback, NULL);
|
||||
is_ruby_log_callback_present = true;
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
void Init_whisper() {
|
||||
id_to_s = rb_intern("to_s");
|
||||
id_call = rb_intern("call");
|
||||
|
|
@ -189,9 +133,19 @@ void Init_whisper() {
|
|||
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
|
||||
id_cache = rb_intern("cache");
|
||||
id_n_processors = rb_intern("n_processors");
|
||||
id_extended = rb_intern("extended");
|
||||
id_start_log_callback_thread = rb_intern("start_log_callback_thread");
|
||||
id_log_callback_thread = rb_intern("@log_callback_thread");
|
||||
id_alive_p = rb_intern("alive?");
|
||||
id_join = rb_intern("join");
|
||||
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
rb_require("whisper/log_settable");
|
||||
mLogSettable = rb_path2class("Whisper::LogSettable");
|
||||
mVAD = rb_define_module_under(mWhisper, "VAD");
|
||||
rb_require("whisper/output");
|
||||
mOutputContext = rb_path2class("Whisper::Output::Context");
|
||||
mOutputSegment = rb_path2class("Whisper::Output::Segment");
|
||||
|
||||
rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
|
||||
|
|
@ -222,8 +176,8 @@ void Init_whisper() {
|
|||
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
|
||||
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
|
||||
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
|
||||
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
|
||||
|
||||
LOG_SETTABLE_INIT(whisper_log_queue, mWhisper)
|
||||
|
||||
cContext = init_ruby_whisper_context(&mWhisper);
|
||||
init_ruby_whisper_context_params(&cContext);
|
||||
|
|
@ -236,8 +190,10 @@ void Init_whisper() {
|
|||
init_ruby_whisper_vad_segment(&mVAD);
|
||||
init_ruby_whisper_vad_segments(&mVAD);
|
||||
init_ruby_whisper_vad_context(&mVAD);
|
||||
init_ruby_whisper_parakeet(&mWhisper);
|
||||
|
||||
rb_require("whisper/context");
|
||||
rb_require("whisper/segment");
|
||||
rb_require("whisper/model/uri");
|
||||
|
||||
rb_include_module(cContext, mOutputContext);
|
||||
rb_include_module(cSegment, mOutputSegment);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,8 +5,12 @@
|
|||
#include <ruby/version.h>
|
||||
#include <ruby/util.h>
|
||||
#include <ruby/thread.h>
|
||||
#include <ruby/thread_native.h>
|
||||
#include <ruby/atomic.h>
|
||||
#include <ruby/memory_view.h>
|
||||
#include "whisper.h"
|
||||
#include "parakeet.h"
|
||||
#include "ruby_whisper_log_settable.h"
|
||||
|
||||
#if RUBY_API_VERSION_MAJOR < 4
|
||||
// Exists but not declared as public API
|
||||
|
|
@ -20,13 +24,28 @@ typedef struct {
|
|||
VALUE callbacks;
|
||||
} ruby_whisper_callback_container;
|
||||
|
||||
typedef struct {
|
||||
VALUE *context;
|
||||
VALUE user_data;
|
||||
VALUE callback;
|
||||
VALUE callbacks;
|
||||
bool is_interrupted;
|
||||
} ruby_whisper_abort_callback_container;
|
||||
typedef struct ruby_whisper_abort_callback_user_data {
|
||||
volatile rb_atomic_t is_interrupted;
|
||||
ruby_whisper_callback_container *callback_container;
|
||||
} ruby_whisper_abort_callback_user_data;
|
||||
|
||||
typedef struct ruby_whisper_log {
|
||||
enum ggml_log_level level;
|
||||
char *text;
|
||||
size_t length;
|
||||
size_t capacity;
|
||||
} ruby_whisper_log;
|
||||
|
||||
typedef struct ruby_whisper_log_queue {
|
||||
rb_nativethread_lock_t lock;
|
||||
rb_nativethread_cond_t cond;
|
||||
bool is_open;
|
||||
|
||||
size_t head;
|
||||
size_t tail;
|
||||
size_t size;
|
||||
ruby_whisper_log *logs;
|
||||
} ruby_whisper_log_queue;
|
||||
|
||||
typedef struct {
|
||||
struct whisper_context *context;
|
||||
|
|
@ -42,7 +61,7 @@ typedef struct {
|
|||
ruby_whisper_callback_container *new_segment_callback_container;
|
||||
ruby_whisper_callback_container *progress_callback_container;
|
||||
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
VALUE vad_params;
|
||||
} ruby_whisper_params;
|
||||
|
||||
|
|
@ -84,6 +103,63 @@ typedef struct parsed_samples_t {
|
|||
bool memview_exported;
|
||||
} parsed_samples_t;
|
||||
|
||||
typedef struct {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
} ruby_whisper_full_args;
|
||||
|
||||
typedef struct ruby_whisper_full_parallel_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
} ruby_whisper_full_parallel_args;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_full_params params;
|
||||
ruby_whisper_callback_container *new_segment_callback_container;
|
||||
ruby_whisper_callback_container *new_token_callback_container;
|
||||
ruby_whisper_callback_container *progress_callback_container;
|
||||
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
} ruby_whisper_parakeet_params;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context_params params;
|
||||
} ruby_whisper_parakeet_context_params;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context *context;
|
||||
} ruby_whisper_parakeet_context;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
int index;
|
||||
} ruby_whisper_parakeet_segment;
|
||||
|
||||
typedef struct {
|
||||
parakeet_token_data *token_data;
|
||||
VALUE text;
|
||||
} ruby_whisper_parakeet_token;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
} ruby_whisper_parakeet_model;
|
||||
|
||||
extern ID id_extended;
|
||||
extern ID id_log_callback_thread;
|
||||
extern ID id_start_log_callback_thread;
|
||||
extern ID id_alive_p;
|
||||
extern ID id_join;
|
||||
extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue);
|
||||
extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue);
|
||||
extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue);
|
||||
extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text);
|
||||
extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue);
|
||||
|
||||
#define GetContext(obj, rw) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \
|
||||
if ((rw)->context == NULL) { \
|
||||
|
|
@ -120,4 +196,47 @@ typedef struct parsed_samples_t {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetContextParams(obj, rwpcp) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetContext(obj, rwpc) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \
|
||||
if ((rwpc)->context == NULL) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetParams(obj, rwpp) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \
|
||||
if (!(rwpp)->new_segment_callback_container || \
|
||||
!(rwpp)->new_token_callback_container || \
|
||||
!(rwpp)->progress_callback_container || \
|
||||
!(rwpp)->encoder_begin_callback_container || \
|
||||
!(rwpp)->abort_callback_container) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetSegment(obj, rwps) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \
|
||||
if (!(rwps)->context) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetToken(obj, rwpt) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \
|
||||
if (!(rwpt)->token_data) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetModel(obj, rwpm) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \
|
||||
if (NIL_P((rwpm)->context)) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
|
|||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_s_new(VALUE context);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
|
||||
|
||||
ID transcribe_option_names[1];
|
||||
|
||||
|
|
@ -38,21 +38,6 @@ typedef struct fill_samples_args {
|
|||
int n_samples;
|
||||
} fill_samples_args;
|
||||
|
||||
typedef struct full_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
} full_args;
|
||||
|
||||
typedef struct full_parallel_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
} full_parallel_args;
|
||||
|
||||
typedef struct full_without_gvl_args {
|
||||
struct whisper_context *context;
|
||||
struct whisper_full_params *params;
|
||||
|
|
@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args {
|
|||
} full_parallel_without_gvl_args;
|
||||
|
||||
typedef struct full_ubf_args {
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
|
||||
} full_ubf_args;
|
||||
|
||||
static void
|
||||
|
|
@ -379,7 +364,7 @@ fill_samples(VALUE rb_args)
|
|||
return Qnil;
|
||||
}
|
||||
|
||||
struct parsed_samples_t
|
||||
parsed_samples_t
|
||||
parse_samples(VALUE *samples, VALUE *n_samples)
|
||||
{
|
||||
bool memview_available = rb_memory_view_available_p(*samples);
|
||||
|
|
@ -480,20 +465,24 @@ full_ubf(void *rb_args)
|
|||
{
|
||||
full_ubf_args *args = (full_ubf_args *)rb_args;
|
||||
|
||||
args->abort_callback_container->is_interrupted = true;
|
||||
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
VALUE
|
||||
full_body(VALUE rb_args)
|
||||
{
|
||||
full_args *args = (full_args *)rb_args;
|
||||
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context, 1);
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
prepare_transcription(rwp, args->context, 1, &abort_callback_user_data);
|
||||
|
||||
struct full_without_gvl_args full_without_gvl_args = {
|
||||
rw->context,
|
||||
|
|
@ -503,7 +492,7 @@ full_body(VALUE rb_args)
|
|||
0,
|
||||
};
|
||||
full_ubf_args full_ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args);
|
||||
return INT2NUM(full_without_gvl_args.result);
|
||||
|
|
@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
|
|||
VALUE n_samples = argc == 2 ? Qnil : argv[2];
|
||||
|
||||
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
|
||||
full_args args = {
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
&argv[0],
|
||||
parsed.samples,
|
||||
|
|
@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
VALUE
|
||||
full_parallel_body(VALUE rb_args)
|
||||
{
|
||||
full_parallel_args *args = (full_parallel_args *)rb_args;
|
||||
ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args;
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context, args->n_processors);
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data);
|
||||
|
||||
struct full_parallel_without_gvl_args full_parallel_without_gvl_args = {
|
||||
rw->context,
|
||||
|
|
@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args)
|
|||
0,
|
||||
};
|
||||
full_ubf_args full_ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args);
|
||||
return INT2NUM(full_parallel_without_gvl_args.result);
|
||||
|
|
@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
|
|||
break;
|
||||
}
|
||||
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
|
||||
const full_parallel_args args = {
|
||||
const ruby_whisper_full_parallel_args args = {
|
||||
&self,
|
||||
&argv[0],
|
||||
parsed.samples,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define LOG_QUEUE_CAPACITY 256
|
||||
#define LOG_DEFAULT_CAPACITY 1024
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
rb_nativethread_lock_initialize(&log_queue->lock);
|
||||
rb_native_cond_initialize(&log_queue->cond);
|
||||
log_queue->head = 0;
|
||||
log_queue->tail = 0;
|
||||
log_queue->size = 0;
|
||||
log_queue->is_open = true;
|
||||
log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY);
|
||||
for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) {
|
||||
// we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL
|
||||
// this doesn't be freed because log queue lives until the end of process
|
||||
char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY);
|
||||
if (!slot) {
|
||||
rb_raise(rb_eRuntimeError, "Could not allocate memory for log text");
|
||||
}
|
||||
ruby_whisper_log log = {
|
||||
0,
|
||||
slot,
|
||||
0,
|
||||
LOG_QUEUE_CAPACITY,
|
||||
};
|
||||
log_queue->logs[i] = log;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
log_queue->is_open = true;
|
||||
|
||||
rb_native_cond_signal(&log_queue->cond);
|
||||
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
}
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
log_queue->is_open = false;
|
||||
rb_native_cond_broadcast(&log_queue->cond);
|
||||
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
}
|
||||
|
||||
static size_t
|
||||
calc_enough_cap(size_t len)
|
||||
{
|
||||
size_t quot = len / LOG_DEFAULT_CAPACITY;
|
||||
size_t rem = len % LOG_DEFAULT_CAPACITY;
|
||||
|
||||
return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY;
|
||||
}
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text)
|
||||
{
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
if (!log_queue->is_open) {
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t len = strlen(text);
|
||||
ruby_whisper_log *log = &log_queue->logs[log_queue->head];
|
||||
if (len > log->capacity) {
|
||||
size_t new_cap = calc_enough_cap(len);
|
||||
// we cannot call Ruby API like REALLOC_N because this function is called without GVL
|
||||
char *slot = realloc(log->text, new_cap);
|
||||
if (!slot) {
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
return;
|
||||
}
|
||||
log->text = slot;
|
||||
log->capacity = new_cap;
|
||||
}
|
||||
// we cannot call Ruby API like MEMCPY because this function is called without GVL
|
||||
memcpy(log->text, text, sizeof(char) * len);
|
||||
log->length = len;
|
||||
log->level = level;
|
||||
log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY;
|
||||
bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY;
|
||||
log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1;
|
||||
if (is_full) {
|
||||
log_queue->tail = log_queue->head;
|
||||
}
|
||||
|
||||
rb_native_cond_signal(&log_queue->cond);
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
}
|
||||
|
||||
static void*
|
||||
ruby_whisper_log_queue_wait(void *args)
|
||||
{
|
||||
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
|
||||
|
||||
rb_native_cond_wait(&log_queue->cond, &log_queue->lock);
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_log_queue_wait_ubf(void *args)
|
||||
{
|
||||
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
|
||||
|
||||
rb_native_cond_broadcast(&log_queue->cond);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
enum ggml_log_level level;
|
||||
size_t length;
|
||||
char *text;
|
||||
} log_snapshot;
|
||||
|
||||
VALUE
|
||||
ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
log_snapshot logs[LOG_QUEUE_CAPACITY];
|
||||
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
while (log_queue->size == 0 && log_queue->is_open) {
|
||||
rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue);
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
}
|
||||
|
||||
if (log_queue->size == 0 && !log_queue->is_open) {
|
||||
rb_native_cond_broadcast(&log_queue->cond);
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
size_t size = log_queue->size;
|
||||
ruby_whisper_log *log;
|
||||
size_t i;
|
||||
for (i = 0; i < size; i++) {
|
||||
log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY];
|
||||
logs[i].level = log->level;
|
||||
logs[i].length = log->length;
|
||||
char *text = malloc(log->length);
|
||||
if (!text) {
|
||||
logs[i].text = NULL;
|
||||
continue;
|
||||
}
|
||||
logs[i].text = text;
|
||||
memcpy(logs[i].text, log->text, log->length);
|
||||
}
|
||||
log_queue->size = 0;
|
||||
log_queue->tail = log_queue->head;
|
||||
|
||||
rb_native_cond_signal(&log_queue->cond);
|
||||
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
|
||||
VALUE rb_logs = rb_ary_new2(size);
|
||||
VALUE rb_text;
|
||||
for (i = 0; i < size; i++) {
|
||||
if (!logs[i].text) {
|
||||
continue;
|
||||
}
|
||||
rb_text = rb_str_new(logs[i].text, logs[i].length);
|
||||
free(logs[i].text);
|
||||
rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text));
|
||||
}
|
||||
|
||||
return rb_logs;
|
||||
}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
#ifndef RUBY_WHISPER_LOG_SETTABLE_H
|
||||
#define RUBY_WHISPER_LOG_SETTABLE_H
|
||||
|
||||
#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \
|
||||
static VALUE \
|
||||
ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \
|
||||
{ \
|
||||
return ruby_whisper_log_queue_drain(&log_queue); \
|
||||
} \
|
||||
static void \
|
||||
ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \
|
||||
{ \
|
||||
ruby_whisper_log_queue_enqueue(&log_queue, level, text); \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \
|
||||
{ \
|
||||
rb_iv_set(self, "@log_callback", log_callback); \
|
||||
rb_iv_set(self, "@log_callback_user_data", user_data); \
|
||||
if (NIL_P(log_callback)) { \
|
||||
log_set(NULL, NULL); \
|
||||
} else { \
|
||||
ruby_whisper_log_queue_open(&log_queue); \
|
||||
rb_funcall((mod), id_start_log_callback_thread, 0); \
|
||||
log_set(ruby_whisper_##log_queue##_log_callback, NULL); \
|
||||
} \
|
||||
return Qnil; \
|
||||
} \
|
||||
static void \
|
||||
ruby_whisper_##log_queue##_end_proc(VALUE args) \
|
||||
{ \
|
||||
ruby_whisper_log_queue_close(&log_queue); \
|
||||
VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \
|
||||
if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \
|
||||
rb_funcall(log_callback_thread, id_join, 0); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define LOG_SETTABLE_INIT(log_queue, mod) \
|
||||
ruby_whisper_log_queue_initialize(&log_queue); \
|
||||
rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \
|
||||
rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \
|
||||
rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \
|
||||
rb_extend_object(mod, mLogSettable); \
|
||||
rb_funcall(mLogSettable, id_extended, 1, mod);
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
#include "ruby_whisper.h"
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
extern VALUE mParakeet;
|
||||
extern VALUE mLogSettable;
|
||||
extern VALUE cParakeetContext;
|
||||
extern VALUE cParakeetSegment;
|
||||
extern VALUE mOutputContext;
|
||||
extern VALUE mOutputSegment;
|
||||
|
||||
extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet);
|
||||
extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext);
|
||||
extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet);
|
||||
|
||||
static ruby_whisper_log_queue parakeet_log_queue;
|
||||
|
||||
LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_s_system_info_str(VALUE self)
|
||||
{
|
||||
return rb_str_new2(parakeet_print_system_info());
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet(VALUE *mWhisper)
|
||||
{
|
||||
mParakeet = rb_define_module_under(*mWhisper, "Parakeet");
|
||||
|
||||
rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version()));
|
||||
|
||||
LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet)
|
||||
|
||||
rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0);
|
||||
|
||||
init_ruby_whisper_parakeet_params(&mParakeet);
|
||||
init_ruby_whisper_parakeet_token(&mParakeet);
|
||||
init_ruby_whisper_parakeet_segment(&mParakeet);
|
||||
cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet);
|
||||
init_ruby_whisper_parakeet_context_params(&cParakeetContext);
|
||||
init_ruby_whisper_parakeet_model(&mParakeet);
|
||||
|
||||
rb_include_module(cParakeetContext, mOutputContext);
|
||||
rb_include_module(cParakeetSegment, mOutputSegment);
|
||||
}
|
||||
|
|
@ -0,0 +1,304 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_SEGMENT_ATTRS(ITERATOR) \
|
||||
ITERATOR(get_segment_t0, LONG) \
|
||||
ITERATOR(get_segment_t1, LONG) \
|
||||
ITERATOR(get_segment_text, STRING) \
|
||||
ITERATOR(n_tokens, INT)
|
||||
|
||||
#define ITERATE_TOKEN_ATTRS(ITERATOR) \
|
||||
ITERATOR(get_token_text, STRING) \
|
||||
ITERATOR(get_token_id, INT) \
|
||||
ITERATOR(get_token_p, FLOAT)
|
||||
|
||||
#define VAL_FROM_LONG(v) LONG2NUM(v)
|
||||
#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v)
|
||||
#define VAL_FROM_INT(v) INT2NUM(v)
|
||||
#define VAL_FROM_FLOAT(v) DBL2NUM(v)
|
||||
#define READER(type) VAL_FROM_##type
|
||||
|
||||
extern ID id_to_s;
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
extern ID id_new;
|
||||
|
||||
extern VALUE cParakeetContext;
|
||||
extern VALUE eError;
|
||||
|
||||
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
|
||||
extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params);
|
||||
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
|
||||
extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples);
|
||||
extern VALUE release_samples(VALUE rb_parsed_args);
|
||||
extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
|
||||
extern rb_data_type_t ruby_whisper_parakeet_params_type;
|
||||
extern rb_data_type_t ruby_whisper_parakeet_context_params_type;
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
|
||||
extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context);
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_context_free(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
|
||||
if (rwpc->context) {
|
||||
parakeet_free(rwpc->context);
|
||||
rwpc->context = NULL;
|
||||
}
|
||||
xfree(rwpc);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_context_memsize(const void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
|
||||
if (!rwpc) {
|
||||
return 0;
|
||||
}
|
||||
size_t size = sizeof(*rwpc);
|
||||
return size;
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_context_type = {
|
||||
"ruby_whisper_parakeet_context",
|
||||
{0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
|
||||
rwpc->context = NULL;
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context **context;
|
||||
char *model_path;
|
||||
struct parakeet_context_params params;
|
||||
} ruby_whisper_parakeet_context_init_args;
|
||||
|
||||
static void*
|
||||
ruby_whisper_parakeet_context_init_without_gvl(void *args)
|
||||
{
|
||||
ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args;
|
||||
*init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
VALUE model_path;
|
||||
VALUE context_params;
|
||||
struct parakeet_context_params params;
|
||||
|
||||
rb_scan_args(argc, argv, "11", &model_path, &context_params);
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
|
||||
|
||||
model_path = ruby_whisper_normalize_model_path(model_path);
|
||||
if (!rb_respond_to(model_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context");
|
||||
}
|
||||
if (NIL_P(context_params)) {
|
||||
params = parakeet_context_default_params();
|
||||
} else {
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
GetParakeetContextParams(context_params, rwpcp);
|
||||
params = rwpcp->params;
|
||||
}
|
||||
ruby_whisper_parakeet_context_init_args init_args = {
|
||||
&rwpc->context,
|
||||
StringValueCStr(model_path),
|
||||
params,
|
||||
};
|
||||
rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL);
|
||||
if (rwpc->context == NULL) {
|
||||
rb_raise(rb_eRuntimeError, "Failed to load model");
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_full_n_segments(VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(self, rwpc);
|
||||
|
||||
return INT2NUM(parakeet_full_n_segments(rwpc->context));
|
||||
}
|
||||
|
||||
#define DEF_SEGMENT_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetContext(self, rwpc); \
|
||||
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \
|
||||
}
|
||||
|
||||
ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR)
|
||||
|
||||
#define DEF_TOKEN_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetContext(self, rwpc); \
|
||||
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \
|
||||
}
|
||||
|
||||
ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(self, rwpc);
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token));
|
||||
|
||||
return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_each_segment(VALUE self)
|
||||
{
|
||||
if (!rb_block_given_p()) {
|
||||
const VALUE method_name = rb_funcall(self, id___method__, 0);
|
||||
return rb_funcall(self, id_to_enum, 1, method_name);
|
||||
}
|
||||
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(self, rwpc);
|
||||
|
||||
const int n_segments = parakeet_full_n_segments(rwpc->context);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
rb_yield(ruby_whisper_parakeet_segment_init(self, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context *context;
|
||||
struct parakeet_full_params *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
int result;
|
||||
} parakeet_full_without_gvl_args;
|
||||
|
||||
static void*
|
||||
parakeet_full_without_gvl(void *rb_args)
|
||||
{
|
||||
parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args;
|
||||
args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
|
||||
} parakeet_full_ubf_args;
|
||||
|
||||
static void
|
||||
parakeet_full_ubf(void *rb_args)
|
||||
{
|
||||
parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args;
|
||||
|
||||
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_context_full_body(VALUE rb_args)
|
||||
{
|
||||
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(*args->context, rwpc);
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
GetParakeetParams(*args->params, rwpp);
|
||||
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data);
|
||||
|
||||
parakeet_full_without_gvl_args full_without_gvl_args = {
|
||||
rwpc->context,
|
||||
&rwpp->params,
|
||||
args->samples,
|
||||
args->n_samples,
|
||||
0
|
||||
};
|
||||
parakeet_full_ubf_args full_ubf_args = {
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args);
|
||||
|
||||
return INT2NUM(full_without_gvl_args.result);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
if (argc < 2 || argc > 3) {
|
||||
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
|
||||
}
|
||||
|
||||
VALUE n_samples = argc == 2 ? Qnil : argv[2];
|
||||
|
||||
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
&argv[0],
|
||||
parsed.samples,
|
||||
parsed.n_samples,
|
||||
};
|
||||
VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed);
|
||||
const int result = NUM2INT(rb_result);
|
||||
if (result == 0) {
|
||||
return self;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
|
||||
}
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_get_model(VALUE self)
|
||||
{
|
||||
return ruby_whisper_parakeet_model_s_new(self);
|
||||
}
|
||||
|
||||
VALUE
|
||||
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate);
|
||||
|
||||
rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1);
|
||||
rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2);
|
||||
rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0);
|
||||
rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2);
|
||||
rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0);
|
||||
rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0);
|
||||
rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1);
|
||||
|
||||
#define REGISTER_SEGMENT_ATTR(name, type) \
|
||||
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1);
|
||||
|
||||
ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR)
|
||||
|
||||
#define REGISTER_TOKEN_ATTR(name, type) \
|
||||
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2);
|
||||
|
||||
ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR)
|
||||
|
||||
return cParakeetContext;
|
||||
}
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(use_gpu, BOOL) \
|
||||
ITERATOR(gpu_device, INT)
|
||||
|
||||
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
|
||||
#define VAL_TO_BOOL(v) (RTEST(v))
|
||||
#define VAL_FROM_INT(v) (INT2NUM(v))
|
||||
#define VAL_TO_INT(v) (NUM2INT(v))
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define WRITER(type) VAL_TO_##type
|
||||
|
||||
#define DEF_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_params_get_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context_params *rwpcp; \
|
||||
GetParakeetContextParams(self, rwpcp); \
|
||||
return READER(type)(rwpcp->params.name); \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context_params *rwpcp; \
|
||||
GetParakeetContextParams(self, rwpcp); \
|
||||
rwpcp->params.name = WRITER(type)(val); \
|
||||
return val; \
|
||||
}
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name,
|
||||
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS
|
||||
};
|
||||
|
||||
extern VALUE cParakeetContextParams;
|
||||
|
||||
typedef VALUE (*param_writer_t)(VALUE, VALUE);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
|
||||
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_context_params_memsize(const void *p)
|
||||
{
|
||||
if (!p) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_context_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_context_params_type = {
|
||||
"ruby_whisper_parakeet_context_params",
|
||||
{0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,},
|
||||
0, 0,
|
||||
0,
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
VALUE kw_hash;
|
||||
VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef};
|
||||
VALUE value;
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
int i;
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
|
||||
rwpcp->params = parakeet_context_default_params();
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values);
|
||||
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) {
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
param_writers[i](self, value);
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext)
|
||||
{
|
||||
cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate);
|
||||
|
||||
rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1);
|
||||
|
||||
int i = 0;
|
||||
#define REGISTER_ATTR(name, type) \
|
||||
param_names[i] = rb_intern(#name); \
|
||||
param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \
|
||||
rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \
|
||||
rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \
|
||||
i++;
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(n_vocab) \
|
||||
ITERATOR(n_audio_ctx) \
|
||||
ITERATOR(n_audio_state) \
|
||||
ITERATOR(n_audio_head) \
|
||||
ITERATOR(n_audio_layer) \
|
||||
ITERATOR(n_mels) \
|
||||
ITERATOR(ftype)
|
||||
|
||||
extern rb_data_type_t ruby_whisper_parakeet_context_type;
|
||||
extern VALUE cParakeetModel;
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_model_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p;
|
||||
if (!NIL_P(rwpm->context)) {
|
||||
rb_gc_mark(rwpm->context);
|
||||
}
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_model_memsize(const void *p)
|
||||
{
|
||||
if (!p) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_model);
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_parakeet_model_type = {
|
||||
"ruby_whisper_parakeet_model",
|
||||
{ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_model_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_model *rwpm;
|
||||
VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
|
||||
rwpm->context = Qnil;
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_model_s_new(VALUE context)
|
||||
{
|
||||
const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel);
|
||||
ruby_whisper_parakeet_model *rwpm;
|
||||
TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
|
||||
rwpm->context = context;
|
||||
return model;
|
||||
}
|
||||
|
||||
#define DEF_ATTR(name) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_model_get_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_model *rwpm; \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetModel(self, rwpm); \
|
||||
GetParakeetContext(rwpm->context, rwpc); \
|
||||
return INT2NUM(parakeet_model_##name(rwpc->context)); \
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_model(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate);
|
||||
|
||||
#define REGISTER_ATTR(name) \
|
||||
rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0);
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
}
|
||||
|
|
@ -0,0 +1,548 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_PARAMS(ITERATOR) \
|
||||
ITERATOR(n_threads, INT) \
|
||||
ITERATOR(offset_ms, INT) \
|
||||
ITERATOR(duration_ms, INT) \
|
||||
ITERATOR(no_context, BOOL) \
|
||||
ITERATOR(audio_ctx, INT)
|
||||
|
||||
#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \
|
||||
ITERATOR(new_segment, DATA) \
|
||||
ITERATOR(new_token, DATA) \
|
||||
ITERATOR(progress, DATA) \
|
||||
ITERATOR(encoder_begin, DATA)
|
||||
|
||||
#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback)
|
||||
#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
|
||||
ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR)
|
||||
|
||||
#define ITERATE_CALLBACK_PARAMS(ITERATOR) \
|
||||
ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
|
||||
ITERATOR(abort_callback)
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name,
|
||||
#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name,
|
||||
#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data,
|
||||
ITERATE_PARAMS(DEF_IDX)
|
||||
ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK)
|
||||
ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA)
|
||||
|
||||
RUBY_WHISPER_PARAKEET_NUM_PARAMS
|
||||
};
|
||||
|
||||
#define VAL_TO_INT(v) (NUM2INT(v))
|
||||
#define VAL_FROM_INT(v) (INT2NUM(v))
|
||||
#define VAL_TO_BOOL(v) (RTEST(v))
|
||||
#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse)
|
||||
|
||||
extern VALUE cParakeetParams;
|
||||
extern ID id_call;
|
||||
|
||||
extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
|
||||
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
|
||||
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
|
||||
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
|
||||
typedef VALUE (*param_writer_t)(VALUE, VALUE);
|
||||
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
int n_new;
|
||||
} call_parakeet_new_segment_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_new_segment_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
const int n_segments = parakeet_full_n_segments_from_state(args->state);
|
||||
for (int i = args->n_new; i > 0; i--) {
|
||||
int i_segment = n_segments - i;
|
||||
VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment);
|
||||
for (int j = 0; j < n_callbacks; j++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||
rb_funcall(cb, id_call, 1, segment);
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_new_segment_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
n_new,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_context *context;
|
||||
struct parakeet_state *state;
|
||||
const parakeet_token_data *token_data;
|
||||
} call_parakeet_new_token_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_new_token_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args;
|
||||
VALUE token = Qnil;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
if (NIL_P(token)) {
|
||||
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
|
||||
}
|
||||
for (int i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
rb_funcall(cb, id_call, 1, token);
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_new_token_callbacks_args args = {
|
||||
container,
|
||||
context,
|
||||
state,
|
||||
token_data,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
int progress;
|
||||
} call_parakeet_progress_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_progress_callback(void *v_args)
|
||||
{
|
||||
call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
rb_funcall(cb, id_call, 1, INT2NUM(args->progress));
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_progress_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
progress,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
bool is_continued;
|
||||
} call_parakeet_encoder_begin_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_encoder_begin_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
|
||||
if (result == Qfalse) {
|
||||
args->is_continued = false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
result = rb_funcall(cb, id_call, 0);
|
||||
if (result == Qfalse) {
|
||||
args->is_continued = false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
call_parakeet_encoder_begin_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
true,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args);
|
||||
|
||||
return args.is_continued;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
bool is_interrupted;
|
||||
} call_parakeet_abort_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_abort_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
VALUE cb;
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
cb = rb_ary_entry(container->callbacks, i);
|
||||
result = rb_funcall(cb, id_call, 0);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_parakeet_abort_callback(void *user_data)
|
||||
{
|
||||
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
|
||||
|
||||
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
|
||||
if (is_interrupted) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
call_parakeet_abort_callbacks_args args = {
|
||||
data->callback_container,
|
||||
false,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args);
|
||||
|
||||
return args.is_interrupted;
|
||||
}
|
||||
|
||||
#define CALLBACK_CONTAINER_NAME(name) name ## _container
|
||||
|
||||
void
|
||||
ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
|
||||
{
|
||||
#define PARAM_NAME(name) name
|
||||
#define USER_DATA_NAME(name) name##_user_data
|
||||
#define REGISTER_CALLBACK(name) \
|
||||
if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \
|
||||
rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \
|
||||
rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \
|
||||
rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \
|
||||
}
|
||||
|
||||
ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK)
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) {
|
||||
abort_callback_user_data->callback_container = rwpp->abort_callback_container;
|
||||
}
|
||||
rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback;
|
||||
rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_params_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
|
||||
|
||||
#define MARK_CONTAINER(name) \
|
||||
if (rwpp->name##_container) { \
|
||||
ruby_whisper_callback_container_mark(rwpp->name##_container); \
|
||||
}
|
||||
|
||||
ITERATE_CALLBACK_PARAMS(MARK_CONTAINER)
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_params_free(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
|
||||
|
||||
#define FREE_CONTAINER(name) \
|
||||
if (rwpp->name##_container) { \
|
||||
xfree(rwpp->name##_container); \
|
||||
}
|
||||
|
||||
ITERATE_CALLBACK_PARAMS(FREE_CONTAINER)
|
||||
|
||||
xfree(rwpp);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_params_memsize(const void *p)
|
||||
{
|
||||
const struct ruby_whisper_parakeet_params *params = p;
|
||||
if (!params) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_params_type = {
|
||||
"ruby_whisper_parakeet_params",
|
||||
{ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define WRITER(type) VAL_TO_##type
|
||||
#define DEF_PARAM_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_get_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
return READER(type)(rwpp->params.name); \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
rwpp->params.name = WRITER(type)(val); \
|
||||
return val; \
|
||||
}
|
||||
|
||||
#define DEF_CALLBACK_PARAM_ATTR(name) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_get_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \
|
||||
return val; \
|
||||
}
|
||||
|
||||
#define DEF_USER_DATA_PARAM_ATTR(name) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \
|
||||
return val; \
|
||||
}
|
||||
|
||||
#define DEF_HOOK(name, data) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_params_on_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_params *rwpp; \
|
||||
GetParakeetParams(self, rwpp); \
|
||||
const VALUE blk = rb_block_proc(); \
|
||||
if (NIL_P(rwpp->name##_callback_container->callbacks)) { \
|
||||
rwpp->name##_callback_container->callbacks = rb_ary_new(); \
|
||||
} \
|
||||
rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \
|
||||
return Qnil; \
|
||||
}
|
||||
|
||||
ITERATE_PARAMS(DEF_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR)
|
||||
ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_params_abort_on(VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
GetParakeetParams(self, rwpp);
|
||||
const VALUE blk = rb_block_proc();
|
||||
if (NIL_P(rwpp->abort_callback_container->callbacks)) {
|
||||
rwpp->abort_callback_container->callbacks = rb_ary_new();
|
||||
}
|
||||
rb_ary_push(rwpp->abort_callback_container->callbacks, blk);
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
|
||||
rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
return obj;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
VALUE kw_hash;
|
||||
VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef};
|
||||
VALUE value;
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
int i;
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
|
||||
|
||||
#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate();
|
||||
|
||||
ITERATE_CALLBACK_PARAMS(INIT_CONTAINER)
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values);
|
||||
|
||||
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) {
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
param_writers[i](self, value);
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_params(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject);
|
||||
rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate);
|
||||
|
||||
rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1);
|
||||
|
||||
int i = 0;
|
||||
#define REGISTER_PARAM(name) \
|
||||
param_names[i] = rb_intern(#name); \
|
||||
param_writers[i] = ruby_whisper_parakeet_params_set_##name; \
|
||||
rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \
|
||||
rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \
|
||||
i++;
|
||||
|
||||
#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name)
|
||||
#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name)
|
||||
#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data)
|
||||
|
||||
ITERATE_PARAMS(REGISTER_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR)
|
||||
|
||||
#define REGISTER_HOOK(name, data) \
|
||||
rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0);
|
||||
|
||||
ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _)
|
||||
|
||||
rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0);
|
||||
}
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(start_time, t0, TIME) \
|
||||
ITERATOR(end_time, t1, TIME) \
|
||||
ITERATOR(text, text, STRING)
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name,
|
||||
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
|
||||
};
|
||||
|
||||
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
|
||||
#define VAL_FROM_STRING(v) (rb_str_new2(v))
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define DEF_ATTR(rb_name, c_name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_get_##rb_name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_segment *rwps; \
|
||||
GetParakeetSegment(self, rwps); \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetContext(rwps->context, rwpc); \
|
||||
return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \
|
||||
}
|
||||
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
extern VALUE cParakeetSegment;
|
||||
extern VALUE sym_start_time;
|
||||
extern VALUE sym_end_time;
|
||||
extern VALUE sym_text;
|
||||
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token);
|
||||
|
||||
static void
|
||||
rb_whisper_parakeet_segment_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p;
|
||||
rb_gc_mark(rwps->context);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_segment_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p;
|
||||
if (!rwps) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(*rwps);
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_parakeet_segment_type = {
|
||||
"ruby_whisper_parakeet_segment",
|
||||
{rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_segment_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_segment_init(VALUE context, int index)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
|
||||
const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment);
|
||||
TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
|
||||
rwps->context = context;
|
||||
rwps->index = index;
|
||||
|
||||
return segment;
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_segment_each_token(VALUE self)
|
||||
{
|
||||
if (!rb_block_given_p()) {
|
||||
const VALUE method_name = rb_funcall(self, id___method__, 0);
|
||||
return rb_funcall(self, id_to_enum, 1, method_name);
|
||||
}
|
||||
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
GetParakeetSegment(self, rwps);
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(rwps->context, rwpc);
|
||||
|
||||
const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index);
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
GetParakeetSegment(self, rwps);
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(rwps->context, rwpc);
|
||||
|
||||
VALUE hash = rb_hash_new();
|
||||
long n_keys;
|
||||
if (NIL_P(keys)) {
|
||||
keys = rb_ary_new3(
|
||||
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
|
||||
sym_start_time,
|
||||
sym_end_time,
|
||||
sym_text
|
||||
);
|
||||
n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS;
|
||||
} else {
|
||||
n_keys = RARRAY_LEN(keys);
|
||||
if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) {
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < n_keys; i++) {
|
||||
VALUE key = rb_ary_entry(keys, i);
|
||||
|
||||
#define CHECK_AND_SET_KEY(rb_name, c_name, type) \
|
||||
if (key == sym_##rb_name) { \
|
||||
rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(CHECK_AND_SET_KEY)
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_segment(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate);
|
||||
|
||||
#define REGISTER_ATTR(rb_name, c_name, type) \
|
||||
rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0);
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
|
||||
rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0);
|
||||
rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1);
|
||||
}
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_MEMBERS(ITERATOR) \
|
||||
ITERATOR(id, id, id, id, INT) \
|
||||
ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \
|
||||
ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \
|
||||
ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \
|
||||
ITERATOR(probability, probability, p, p, FLOAT) \
|
||||
ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \
|
||||
ITERATOR(start_time, start_time, start_time, t0, TIME) \
|
||||
ITERATOR(end_time, end_time, end_time, t1, TIME) \
|
||||
ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL)
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(text, text, text, text, STRING)
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name,
|
||||
|
||||
ITERATE_MEMBERS(DEF_IDX)
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS,
|
||||
};
|
||||
|
||||
#define VAL_FROM_INT(v) (INT2NUM(v))
|
||||
#define VAL_FROM_FLOAT(v) (DBL2NUM(v))
|
||||
#define VAL_FROM_TIME(v) (LONG2NUM(v * 10))
|
||||
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
|
||||
#define VAL_FROM_STRING(v) (rb_str_new2(v))
|
||||
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define MEMBER_NAME(name) name
|
||||
#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_token *rwpt; \
|
||||
GetParakeetToken(self, rwpt); \
|
||||
return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \
|
||||
}
|
||||
|
||||
#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_token *rwpt; \
|
||||
GetParakeetToken(self, rwpt); \
|
||||
return rwpt->p_name; \
|
||||
}
|
||||
|
||||
VALUE cParakeetToken;
|
||||
|
||||
#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key;
|
||||
|
||||
ITERATE_MEMBERS(DEC_ATTR_SYMS)
|
||||
ITERATE_ATTRS(DEC_ATTR_SYMS)
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_token_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
|
||||
rb_gc_mark(rwpt->text);
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_token_free(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
|
||||
if (rwpt->token_data) {
|
||||
xfree(rwpt->token_data);
|
||||
rwpt->token_data = NULL;
|
||||
}
|
||||
xfree(rwpt);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_token_memsize(const void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
|
||||
if (!rwpt) {
|
||||
return 0;
|
||||
}
|
||||
size_t size = sizeof(*rwpt);
|
||||
if (rwpt->token_data) {
|
||||
size += sizeof(*rwpt->token_data);
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_parakeet_token_type = {
|
||||
"ruby_whisper_parakeet_token",
|
||||
{ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize},
|
||||
0, 0,
|
||||
0,
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_token_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt;
|
||||
VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
|
||||
|
||||
rwpt->token_data = NULL;
|
||||
rwpt->text = Qnil;
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
|
||||
{
|
||||
const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
|
||||
ruby_whisper_parakeet_token *rwpt;
|
||||
TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
|
||||
|
||||
rwpt->token_data = ALLOC(parakeet_token_data);
|
||||
*rwpt->token_data = *token_data;
|
||||
rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id));
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
|
||||
{
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token);
|
||||
return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data);
|
||||
}
|
||||
|
||||
ITERATE_MEMBERS(DEF_MEMBER_ATTR)
|
||||
// Define #text using parakeet_token_to_str or parakeet_token_to_text
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt;
|
||||
GetParakeetToken(self, rwpt);
|
||||
|
||||
VALUE hash = rb_hash_new();
|
||||
long n_keys = 0;
|
||||
|
||||
if (NIL_P(keys)) {
|
||||
VALUE attrs[] = {
|
||||
#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key,
|
||||
|
||||
ITERATE_MEMBERS(LIST_SYMS)
|
||||
ITERATE_ATTRS(LIST_SYMS)
|
||||
};
|
||||
keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs);
|
||||
n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS;
|
||||
} else {
|
||||
n_keys = RARRAY_LEN(keys);
|
||||
if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) {
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
for (long i = 0; i < n_keys; i++) {
|
||||
VALUE key = rb_ary_entry(keys, i);
|
||||
|
||||
#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \
|
||||
if (key == sym_##s_key) { \
|
||||
rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \
|
||||
}
|
||||
|
||||
ITERATE_MEMBERS(CHECK_AND_SET_KEY)
|
||||
ITERATE_ATTRS(CHECK_AND_SET_KEY)
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_token(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject);
|
||||
rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate);
|
||||
|
||||
#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \
|
||||
sym_##s_key = ID2SYM(rb_intern(#s_key)); \
|
||||
rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0);
|
||||
|
||||
ITERATE_MEMBERS(REGISTER_ATTR)
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
|
||||
rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1);
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
#include "ruby_whisper.h"
|
||||
#include "common-whisper.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
|
||||
extern const rb_data_type_t ruby_whisper_parakeet_params_type;
|
||||
|
||||
extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args);
|
||||
|
||||
extern ID id_to_path;
|
||||
extern ID id_new;
|
||||
|
||||
extern VALUE eError;
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params)
|
||||
{
|
||||
if (rb_respond_to(audio_path, id_to_path)) {
|
||||
audio_path = rb_funcall(audio_path, id_to_path, 0);
|
||||
}
|
||||
|
||||
std::string fname = StringValueCStr(audio_path);
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
|
||||
if (!read_audio_data(fname, pcmf32, pcmf32s, false)) {
|
||||
rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str());
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
GetParakeetContext(self, rwpc);
|
||||
GetParakeetParams(params, rwpp);
|
||||
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
¶ms,
|
||||
pcmf32.data(),
|
||||
(int)pcmf32.size(),
|
||||
};
|
||||
VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args);
|
||||
const int result = NUM2INT(rb_result);
|
||||
if (result == 0) {
|
||||
return self;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -76,8 +76,8 @@ static ID id_vad;
|
|||
static ID id_vad_model_path;
|
||||
static ID id_vad_params;
|
||||
|
||||
static void
|
||||
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
|
||||
void
|
||||
ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc)
|
||||
{
|
||||
if (rwc == NULL) return;
|
||||
|
||||
|
|
@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
|
|||
rb_gc_mark(rwc->callbacks);
|
||||
}
|
||||
|
||||
static ruby_whisper_callback_container*
|
||||
rb_whisper_callback_container_allocate() {
|
||||
ruby_whisper_callback_container*
|
||||
ruby_whisper_callback_container_allocate() {
|
||||
ruby_whisper_callback_container *container;
|
||||
container = ALLOC(ruby_whisper_callback_container);
|
||||
container->context = NULL;
|
||||
|
|
@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() {
|
|||
return container;
|
||||
}
|
||||
|
||||
static void
|
||||
rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc)
|
||||
{
|
||||
if (rwc == NULL) return;
|
||||
|
||||
rb_gc_mark(rwc->user_data);
|
||||
rb_gc_mark(rwc->callback);
|
||||
rb_gc_mark(rwc->callbacks);
|
||||
}
|
||||
|
||||
static ruby_whisper_abort_callback_container*
|
||||
rb_whisper_abort_callback_container_allocate() {
|
||||
ruby_whisper_abort_callback_container *container;
|
||||
container = ALLOC(ruby_whisper_abort_callback_container);
|
||||
container->context = NULL;
|
||||
container->user_data = Qnil;
|
||||
container->callback = Qnil;
|
||||
container->callbacks = Qnil;
|
||||
container->is_interrupted = false;
|
||||
return container;
|
||||
}
|
||||
|
||||
static bool
|
||||
bool
|
||||
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
|
||||
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) {
|
||||
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct whisper_state *state;
|
||||
|
|
@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s
|
|||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_abort_callback_container *container;
|
||||
struct whisper_state *state;
|
||||
const ruby_whisper_callback_container *container;
|
||||
bool is_interrupted;
|
||||
} call_abort_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_abort_callbacks(void *v_args) {
|
||||
call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args;
|
||||
const ruby_whisper_abort_callback_container *container = args->container;
|
||||
|
||||
if (container->is_interrupted) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (!NIL_P(result) && Qfalse != result) {
|
||||
result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) {
|
|||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long callbacks_len = RARRAY_LEN(container->callbacks);
|
||||
if (0 == callbacks_len) {
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (0 == n_callbacks) {
|
||||
return NULL;
|
||||
}
|
||||
for (int j = 0; j < callbacks_len; j++) {
|
||||
for (int j = 0; j < n_callbacks; j++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
|
||||
if (!NIL_P(result) && Qfalse != result) {
|
||||
VALUE result = rb_funcall(cb, id_call, 0);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) {
|
|||
}
|
||||
|
||||
static bool abort_callback(void * user_data) {
|
||||
const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data;
|
||||
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
|
||||
|
||||
if (container->is_interrupted) {
|
||||
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
|
||||
if (is_interrupted) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!ruby_whisper_abort_callback_container_is_present(container)) {
|
||||
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
call_abort_callbacks_args args = {
|
||||
container,
|
||||
NULL,
|
||||
data->callback_container,
|
||||
false
|
||||
};
|
||||
rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args);
|
||||
|
|
@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
|
|||
return;
|
||||
}
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription");
|
||||
}
|
||||
// new_segment_callback is called only after multiple threads are joined
|
||||
// progress_callback is not called when parallel
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) {
|
||||
if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
if (!NIL_P(log_callback)) {
|
||||
rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription");
|
||||
}
|
||||
}
|
||||
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) {
|
||||
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
|
||||
rwp->new_segment_callback_container->context = context;
|
||||
rwp->params.new_segment_callback = new_segment_callback;
|
||||
|
|
@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
|||
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
|
||||
}
|
||||
|
||||
abort_callback_user_data->callback_container = rwp->abort_callback_container;
|
||||
rwp->abort_callback_container->context = context;
|
||||
rwp->params.abort_callback = abort_callback;
|
||||
rwp->abort_callback_container->is_interrupted = false;
|
||||
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
|
||||
rwp->params.abort_callback_user_data = (void *)abort_callback_user_data;
|
||||
}
|
||||
|
||||
static void set_vad_params(ruby_whisper_params *rwp)
|
||||
|
|
@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp)
|
|||
rwp->params.vad_params = rwvp->params;
|
||||
}
|
||||
|
||||
/*
|
||||
TODO: Set abort callback to trap SIGINT and SIGTERM
|
||||
*/
|
||||
void
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
|
||||
{
|
||||
check_thread_safety(rwp, n_processors);
|
||||
register_callbacks(rwp, context);
|
||||
register_callbacks(rwp, context, abort_callback_user_data);
|
||||
set_vad_params(rwp);
|
||||
}
|
||||
|
||||
|
|
@ -421,10 +376,10 @@ void
|
|||
rb_whisper_params_mark(void *p)
|
||||
{
|
||||
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
|
||||
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
|
||||
rb_whisper_abort_callback_container_mark(rwp->abort_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->new_segment_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->progress_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->abort_callback_container);
|
||||
rb_gc_mark(rwp->vad_params);
|
||||
}
|
||||
|
||||
|
|
@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass)
|
|||
}
|
||||
rwp->diarize = false;
|
||||
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
|
||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate();
|
||||
rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = ruby_whisper_callback_container_allocate();
|
||||
return obj;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
static VALUE sym_start_time;
|
||||
static VALUE sym_end_time;
|
||||
static VALUE sym_text;
|
||||
static VALUE sym_no_speech_prob;
|
||||
static VALUE sym_speaker_turn_next;
|
||||
static VALUE sym_n_tokens;
|
||||
VALUE sym_start_time;
|
||||
VALUE sym_end_time;
|
||||
VALUE sym_text;
|
||||
VALUE sym_no_speech_prob;
|
||||
VALUE sym_speaker_turn_next;
|
||||
VALUE sym_n_tokens;
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_type;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ extern ID id_to_path;
|
|||
extern ID transcribe_option_names[1];
|
||||
|
||||
extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
|
||||
extern VALUE full_body(VALUE rb_args);
|
||||
extern VALUE full_parallel_body(VALUE rb_args);
|
||||
|
||||
typedef struct{
|
||||
struct whisper_context *context;
|
||||
|
|
@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
} transcribe_ubf_args;
|
||||
|
||||
static void
|
||||
transcribe_ubf(void *rb_args)
|
||||
{
|
||||
transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args;
|
||||
|
||||
args->abort_callback_container->is_interrupted = true;
|
||||
}
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
* can emit to a block results
|
||||
|
|
@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
// Commented out because it is work in progress
|
||||
// {
|
||||
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
// bool is_aborted = *(bool*)user_data;
|
||||
// return !is_aborted;
|
||||
// };
|
||||
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
// }
|
||||
|
||||
prepare_transcription(rwp, &self, n_processors);
|
||||
|
||||
transcribe_without_gvl_args args = {
|
||||
rw->context,
|
||||
&rwp->params,
|
||||
pcmf32.data(),
|
||||
pcmf32.size(),
|
||||
n_processors,
|
||||
0,
|
||||
};
|
||||
transcribe_ubf_args ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
};
|
||||
rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args);
|
||||
if (args.result != 0) {
|
||||
VALUE rb_result;
|
||||
if (n_processors == 1) {
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
¶ms,
|
||||
pcmf32.data(),
|
||||
(int)pcmf32.size(),
|
||||
};
|
||||
rb_result = full_body((VALUE)&args);
|
||||
} else {
|
||||
ruby_whisper_full_parallel_args parallel_args = {
|
||||
&self,
|
||||
¶ms,
|
||||
pcmf32.data(),
|
||||
(int)pcmf32.size(),
|
||||
n_processors,
|
||||
};
|
||||
rb_result = full_parallel_body((VALUE)¶llel_args);
|
||||
}
|
||||
const int result = NUM2INT(rb_result);
|
||||
if (result != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
module Whisper
|
||||
class Context
|
||||
def to_srt
|
||||
each_segment.with_index.reduce("") {|srt, (segment, index)|
|
||||
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
|
||||
}
|
||||
end
|
||||
|
||||
def to_webvtt
|
||||
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
|
||||
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
require "mutex_m"
|
||||
|
||||
module Whisper
|
||||
module LogSettable
|
||||
class << self
|
||||
def extended(base)
|
||||
base.extend Mutex_m
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def start_log_callback_thread
|
||||
return if @log_callback_thread&.alive?
|
||||
|
||||
@log_callback_thread = Thread.new {
|
||||
begin
|
||||
while logs = drain_logs
|
||||
begin
|
||||
callback, user_data = synchronize {[@log_callback, @log_callback_user_data]}
|
||||
next if callback.nil?
|
||||
|
||||
logs.each do |(level, text)|
|
||||
callback.call level, text, user_data
|
||||
end
|
||||
rescue => err
|
||||
$stderr.puts err
|
||||
end
|
||||
end
|
||||
rescue => err
|
||||
$stderr.puts err
|
||||
end
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -41,6 +41,8 @@ module Whisper
|
|||
|
||||
def cache
|
||||
path = cache_path
|
||||
return path if cache_path.exist?
|
||||
|
||||
headers = {}
|
||||
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
|
||||
request @uri, headers
|
||||
|
|
@ -216,8 +218,18 @@ module Whisper
|
|||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
%w[
|
||||
parakeet-tdt-0.6b-v3-f16
|
||||
parakeet-tdt-0.6b-v3-f32
|
||||
parakeet-tdt-0.6b-v3-q4_0
|
||||
parakeet-tdt-0.6b-v3-q4_k
|
||||
parakeet-tdt-0.6b-v3-q8_0
|
||||
].each do |name|
|
||||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
@coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models|
|
||||
next if name.end_with?("-tdrz") || name.start_with?("silero-")
|
||||
next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-")
|
||||
|
||||
if matched = name.match(/\A(?<name>.*)-q\d_\d\z/)
|
||||
name = matched[:name]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
module Whisper
|
||||
module Output
|
||||
module Context
|
||||
def to_srt
|
||||
each_segment.with_index.reduce("") {|srt, (segment, index)|
|
||||
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
|
||||
}
|
||||
end
|
||||
|
||||
def to_webvtt
|
||||
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
|
||||
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
module Segment
|
||||
SRT_ESCAPES = {
|
||||
"&" => "&",
|
||||
"<" => "<",
|
||||
">" => ">",
|
||||
}
|
||||
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
|
||||
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
|
||||
|
||||
def to_srt_cue
|
||||
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
|
||||
end
|
||||
|
||||
def to_webvtt_cue
|
||||
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def time_to_a(time)
|
||||
sec, decimal_part = time.divmod(1000)
|
||||
min, sec = sec.divmod(60)
|
||||
hour, min = min.divmod(60)
|
||||
[hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
def srt_time(time)
|
||||
"%02d:%02d:%02d,%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def srt_start_time
|
||||
srt_time(start_time)
|
||||
end
|
||||
|
||||
def srt_end_time
|
||||
srt_time(end_time)
|
||||
end
|
||||
|
||||
def srt_text
|
||||
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
|
||||
end
|
||||
|
||||
def webvtt_time(time)
|
||||
"%02d:%02d:%02d.%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def webvtt_start_time
|
||||
webvtt_time(start_time)
|
||||
end
|
||||
|
||||
def webvtt_end_time
|
||||
webvtt_time(end_time)
|
||||
end
|
||||
|
||||
alias webvtt_text srt_text
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
module Whisper
|
||||
class Segment
|
||||
SRT_ESCAPES = {
|
||||
"&" => "&",
|
||||
"<" => "<",
|
||||
">" => ">",
|
||||
}
|
||||
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
|
||||
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
|
||||
|
||||
def to_srt_cue
|
||||
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
|
||||
end
|
||||
|
||||
def to_webvtt_cue
|
||||
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def time_to_a(time)
|
||||
sec, decimal_part = time.divmod(1000)
|
||||
min, sec = sec.divmod(60)
|
||||
hour, min = min.divmod(60)
|
||||
[hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
def srt_time(time)
|
||||
"%02d:%02d:%02d,%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def srt_start_time
|
||||
srt_time(start_time)
|
||||
end
|
||||
|
||||
def srt_end_time
|
||||
srt_time(end_time)
|
||||
end
|
||||
|
||||
def srt_text
|
||||
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
|
||||
end
|
||||
|
||||
def webvtt_time(time)
|
||||
"%02d:%02d:%02d.%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def webvtt_start_time
|
||||
webvtt_time(start_time)
|
||||
end
|
||||
|
||||
def webvtt_end_time
|
||||
webvtt_time(end_time)
|
||||
end
|
||||
|
||||
alias webvtt_text srt_text
|
||||
end
|
||||
end
|
||||
|
|
@ -40,7 +40,21 @@ module Whisper
|
|||
def self.log_set: (log_callback?, Object? user_data) -> log_callback
|
||||
def self.system_info_str: () -> String
|
||||
|
||||
module Output
|
||||
module Context
|
||||
def to_srt: () -> String
|
||||
def to_webvtt: () -> String
|
||||
end
|
||||
|
||||
module Segment
|
||||
def to_srt_cue: () -> String
|
||||
def to_webvtt_cue: () -> String
|
||||
end
|
||||
end
|
||||
|
||||
class Context
|
||||
include Output::Context
|
||||
|
||||
def self.new: (String | path | ::URI::HTTP) -> instance
|
||||
|
||||
# transcribe a single file
|
||||
|
|
@ -139,17 +153,14 @@ module Whisper
|
|||
| (Whisper::Params, _Samples, ?Integer n_samples) -> self
|
||||
| (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
||||
|
||||
def to_srt: () -> String
|
||||
def to_webvtt: () -> String
|
||||
|
||||
class Params
|
||||
def self.new: (
|
||||
use_gpu: boolish,
|
||||
flash_attn: boolish,
|
||||
gpu_device: Integer,
|
||||
dtw_token_timestamps: boolish,
|
||||
dtw_aheads_preset: Integer,
|
||||
dtw_n_top: Integer | nil,
|
||||
?use_gpu: boolish,
|
||||
?flash_attn: boolish,
|
||||
?gpu_device: Integer,
|
||||
?dtw_token_timestamps: boolish,
|
||||
?dtw_aheads_preset: Integer,
|
||||
?dtw_n_top: Integer | nil,
|
||||
) -> instance
|
||||
|
||||
def use_gpu=: (boolish) -> boolish
|
||||
|
|
@ -444,6 +455,9 @@ module Whisper
|
|||
def abort_on: { (Object user_data) -> boolish } -> void
|
||||
end
|
||||
|
||||
module LogSettable
|
||||
end
|
||||
|
||||
class Model
|
||||
def self.pre_converted_models: () -> Hash[String, Model::URI]
|
||||
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
|
||||
|
|
@ -474,6 +488,8 @@ module Whisper
|
|||
end
|
||||
|
||||
class Segment
|
||||
include Output::Segment
|
||||
|
||||
type deconstructed_keys = {
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
|
|
@ -514,9 +530,6 @@ module Whisper
|
|||
#
|
||||
def each_token: { (Token) -> void } -> void
|
||||
| () -> Enumerator[Token]
|
||||
def to_srt_cue: () -> String
|
||||
def to_webvtt_cue: () -> String
|
||||
|
||||
|
||||
# Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next`
|
||||
#
|
||||
|
|
@ -528,7 +541,7 @@ module Whisper
|
|||
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
module Token
|
||||
class Token
|
||||
type deconstructed_keys = {
|
||||
id: (Integer | nil),
|
||||
tid: (Integer | nil),
|
||||
|
|
@ -598,6 +611,336 @@ module Whisper
|
|||
def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
module Parakeet
|
||||
extend LogSettable
|
||||
|
||||
VERSION: String
|
||||
|
||||
# Control logging output. The default behavior is to print to stderr.
|
||||
#
|
||||
def self.log_set: (nil, Object? user_data) -> nil
|
||||
| (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil
|
||||
def self.system_info_str: () -> String
|
||||
|
||||
class Context
|
||||
include Output::Context
|
||||
|
||||
# Load a Parakeet model from the given file path.
|
||||
#
|
||||
def self.new: (String | path | ::URI::HTTP, ?Params) -> instance
|
||||
|
||||
# Transcribe a single audio file.
|
||||
#
|
||||
def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self
|
||||
|
||||
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||
# Not thread safe for the same context.
|
||||
#
|
||||
# The second argument `samples` must be an array of samples, respond to `:length`,
|
||||
# or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
|
||||
#
|
||||
def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self
|
||||
| (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self
|
||||
|
||||
# Number of generated text segments.
|
||||
#
|
||||
def full_n_segments: () -> Integer
|
||||
|
||||
# Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds).
|
||||
#
|
||||
# full_get_segment_t0(3) # => 1668 (16680 ms)
|
||||
#
|
||||
def full_get_segment_t0: (Integer segment_index) -> Integer
|
||||
|
||||
# End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds).
|
||||
#
|
||||
# full_get_segment_t1(3) # => 1668 (16680 ms)
|
||||
#
|
||||
def full_get_segment_t1: (Integer segment_index) -> Integer
|
||||
|
||||
# Text of a segment indexed by `segment_index`.
|
||||
#
|
||||
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
|
||||
#
|
||||
def full_get_segment_text: (Integer segment_index) -> String
|
||||
|
||||
# Number of tokens in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_n_tokens: (Integer segment_index) -> Integer
|
||||
|
||||
# Text of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_text: (Integer segment_index, Integer token_index) -> String
|
||||
|
||||
# Token id of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer
|
||||
|
||||
# Probability of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_p: (Integer segment_index, Integer token_index) -> Float
|
||||
|
||||
# Token data of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_data: (Integer segment_index, Integer token_index) -> Token
|
||||
|
||||
def model: () -> Model
|
||||
|
||||
# Yields each Whisper::Parakeet::Segment:
|
||||
#
|
||||
# parakeet.transcribe("path/to/audio.wav", params)
|
||||
# parakeet.each_segment do |segment|
|
||||
# puts segment.text
|
||||
# end
|
||||
#
|
||||
# Returns an `Enumerator` if no block given:
|
||||
#
|
||||
# parakeet.transcribe("path/to/audio.wav", params)
|
||||
# enum = parakeet.each_segment
|
||||
# enum.to_a # => [#<Whisper::Parakeet::Segment>, ...]
|
||||
#
|
||||
def each_segment: { (Segment) -> void } -> void
|
||||
| () -> Enumerator[Segment]
|
||||
|
||||
class Params
|
||||
def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance
|
||||
def use_gpu: () -> boolish
|
||||
def use_gpu=: (boolish) -> boolish
|
||||
def gpu_device: () -> Integer
|
||||
def gpu_device=: (Integer) -> Integer
|
||||
end
|
||||
end
|
||||
|
||||
class Params
|
||||
def self.new: (
|
||||
?n_threads: Integer,
|
||||
?offset_ms: Integer,
|
||||
?duration_ms: Integer,
|
||||
?no_context: boolish,
|
||||
?audio_ctx: Integer,
|
||||
?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void,
|
||||
?new_segment_callback_user_data: Object,
|
||||
?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void,
|
||||
?new_token_callback_user_data: Object,
|
||||
?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void,
|
||||
?progress_callback_user_data: Object,
|
||||
?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish,
|
||||
?encoder_begin_callback_user_data: Object,
|
||||
?abort_callback: ^(Object user_data) -> boolish,
|
||||
?abort_callback_user_data: Object
|
||||
) -> instance
|
||||
|
||||
# Number of threads to use.
|
||||
#
|
||||
def n_threads=: (Integer) -> Integer
|
||||
def n_threads: () -> Integer
|
||||
|
||||
# Start offset in ms.
|
||||
#
|
||||
def offset_ms=: (Integer) -> Integer
|
||||
def offset_ms: () -> Integer
|
||||
|
||||
# Audio duration to process in ms.
|
||||
#
|
||||
def duration_ms=: (Integer) -> Integer
|
||||
def duration_ms: () -> Integer
|
||||
|
||||
# If `true`, does not use past transcription (if any) as context.
|
||||
#
|
||||
def no_context=: (boolish) -> boolish
|
||||
def no_context: () -> (true | false)
|
||||
|
||||
# Overwrite the audio context size. `0` uses the default value.
|
||||
#
|
||||
def audio_ctx=: (Integer) -> Integer
|
||||
def audio_ctx: () -> Integer
|
||||
|
||||
# Sets new segment callback, called for every newly generated text segment.
|
||||
#
|
||||
# params.new_segment_callback = ->(context, _, n_new, user_data) {
|
||||
# # ...
|
||||
# }
|
||||
#
|
||||
def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void)
|
||||
def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of new segment callback.
|
||||
#
|
||||
def new_segment_callback_user_data=: (Object?) -> Object?
|
||||
def new_segment_callback_user_data: () -> Object?
|
||||
|
||||
# Sets token callback, called for every newly predicted token.
|
||||
#
|
||||
def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void)
|
||||
def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of token callback.
|
||||
#
|
||||
def new_token_callback_user_data=: (Object?) -> Object?
|
||||
def new_token_callback_user_data: () -> Object?
|
||||
|
||||
# Sets progress callback, called on each progress update.
|
||||
#
|
||||
# +progress+ is an Integer between 0 and 100.
|
||||
#
|
||||
def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void)
|
||||
def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of progress callback.
|
||||
#
|
||||
def progress_callback_user_data=: (Object?) -> Object?
|
||||
def progress_callback_user_data: () -> Object?
|
||||
|
||||
# Sets encoder begin callback, called each time before the encoder starts.
|
||||
#
|
||||
# If it returns `false`, the computation is aborted.
|
||||
#
|
||||
def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish)
|
||||
def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of encoder begin callback.
|
||||
#
|
||||
def encoder_begin_callback_user_data=: (Object?) -> Object?
|
||||
def encoder_begin_callback_user_data: () -> Object?
|
||||
|
||||
# Sets abort callback, called each time before ggml computation starts.
|
||||
#
|
||||
def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish)
|
||||
def abort_callback: () -> ((^(Object user_data) -> boolish) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of abort callback.
|
||||
#
|
||||
def abort_callback_user_data=: (Object?) -> Object?
|
||||
def abort_callback_user_data: () -> Object?
|
||||
|
||||
# Hook called on new segment. Yields each Whisper::Parakeet::Segment.
|
||||
#
|
||||
def on_new_segment: { (Segment) -> void } -> void
|
||||
|
||||
# Hook called on new token. Yields each Whisper::Parakeet::Token.
|
||||
#
|
||||
def on_new_token: { (Token) -> void } -> void
|
||||
|
||||
# Hook called on progress update. Yields each progress `Integer` between 0 and 100.
|
||||
#
|
||||
def on_progress: { (Integer progress) -> void } -> void
|
||||
|
||||
# Hook called each time before the encoder starts.
|
||||
#
|
||||
def on_encoder_begin: { () -> boolish } -> void
|
||||
|
||||
# Call block to determine whether abort or not. Return `true` when you want to abort.
|
||||
#
|
||||
def abort_on: { () -> boolish } -> void
|
||||
end
|
||||
|
||||
class Segment
|
||||
include Output::Segment
|
||||
|
||||
type deconstructed_keys = {
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
text: (String | nil)
|
||||
}
|
||||
|
||||
# Start time in milliseconds.
|
||||
#
|
||||
def start_time: () -> Integer
|
||||
|
||||
# End time in milliseconds.
|
||||
#
|
||||
def end_time: () -> Integer
|
||||
|
||||
# Text of the segment.
|
||||
#
|
||||
def text: () -> String
|
||||
|
||||
# Yields each Whisper::Parakeet::Token:
|
||||
#
|
||||
# parakeet.each_segment.first.each_token do |token|
|
||||
# p token
|
||||
# end
|
||||
#
|
||||
# Returns an `Enumerator` if no block is given:
|
||||
#
|
||||
# parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...]
|
||||
#
|
||||
def each_token: { (Token) -> void } -> void
|
||||
| () -> Enumerator[Token]
|
||||
|
||||
# Possible keys: `:start_time`, `:end_time`, `:text`
|
||||
#
|
||||
def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
class Token
|
||||
type deconstructed_keys = {
|
||||
id: (Integer | nil),
|
||||
duration_idx: (Integer | nil),
|
||||
duration_value: (Integer | nil),
|
||||
frame_index: (Integer | nil),
|
||||
probability: (Float | nil),
|
||||
log_probability: (Float | nil),
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
word_start: ((true | false) | nil),
|
||||
text: (String | nil),
|
||||
}
|
||||
|
||||
# Token ID.
|
||||
#
|
||||
def id: () -> Integer
|
||||
|
||||
# Index into the model's durations array.
|
||||
#
|
||||
def duration_idx: () -> Integer
|
||||
|
||||
# Actual duration value.
|
||||
#
|
||||
def duration_value: () -> Integer
|
||||
|
||||
# Frame index of the token.
|
||||
#
|
||||
def frame_index: () -> Integer
|
||||
|
||||
# Probability of the token.
|
||||
#
|
||||
def probability: () -> Float
|
||||
|
||||
# Log probability of the token.
|
||||
#
|
||||
def log_probability: () -> Float
|
||||
|
||||
# Start time of the token in milliseconds.
|
||||
#
|
||||
def start_time: () -> Integer
|
||||
|
||||
# End time of the token in milliseconds.
|
||||
#
|
||||
def end_time: () -> Integer
|
||||
|
||||
# Whether this token is the start of a word.
|
||||
#
|
||||
def word_start?: () -> (true | false)
|
||||
|
||||
# Get the token text of the token.
|
||||
#
|
||||
def text: () -> String
|
||||
|
||||
def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
class Model
|
||||
def n_vocab: () -> Integer
|
||||
def n_audio_ctx: () -> Integer
|
||||
def n_audio_state: () -> Integer
|
||||
def n_audio_head: () -> Integer
|
||||
def n_audio_layer: () -> Integer
|
||||
def n_mels: () -> Integer
|
||||
def ftype: () -> Integer
|
||||
end
|
||||
end
|
||||
|
||||
module VAD
|
||||
class Params
|
||||
def self.new: (
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ require_relative "jfk_reader/jfk_reader"
|
|||
class TestBase < Test::Unit::TestCase
|
||||
AUDIO = File.join(__dir__, "fixtures", "jfk.wav")
|
||||
|
||||
Parakeet = Whisper::Parakeet
|
||||
|
||||
class << self
|
||||
def whisper
|
||||
return @whisper if @whisper
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ class TestCallback < TestBase
|
|||
return false
|
||||
}
|
||||
@whisper.transcribe(@audio, @params)
|
||||
sleep 0.5 # wait for logs dequeued
|
||||
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
|
||||
Whisper.log_set ->(level, buffer, user_data) {}, nil
|
||||
end
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
require_relative "helper"
|
||||
require "stringio"
|
||||
|
||||
class TestParakeet < TestBase
|
||||
def test_log_set
|
||||
log_callback = Parakeet.instance_variable_get("@log_callback")
|
||||
user_data = Parakeet.instance_variable_get("@log_callback_user_data")
|
||||
|
||||
$stdout = StringIO.new
|
||||
Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil
|
||||
Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
|
||||
sleep 0.1
|
||||
$stdout.rewind
|
||||
logs = $stdout.string
|
||||
assert_match /loading model from/, logs
|
||||
ensure
|
||||
$stdout = STDOUT
|
||||
Parakeet.log_set log_callback, user_data
|
||||
end
|
||||
|
||||
def test_system_info_str
|
||||
assert_match /\APARAKEET : /, Parakeet.system_info_str
|
||||
end
|
||||
|
||||
def test_version
|
||||
assert_instance_of String, Parakeet::VERSION
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetCallback < TestBase
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
Whisper.instance_variable_set "@whisper", nil
|
||||
GC.start
|
||||
@params = Parakeet::Params.new
|
||||
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
end
|
||||
|
||||
def test_new_segment_callback
|
||||
@params.new_segment_callback = ->(context, state, n_new, user_data) {
|
||||
assert_kind_of Integer, n_new
|
||||
assert n_new > 0
|
||||
assert_same @parakeet, context
|
||||
|
||||
n_segments = context.full_n_segments
|
||||
n_new.times do |i|
|
||||
i_segment = n_segments - 1 + i
|
||||
start_time = context.full_get_segment_t0(i_segment) * 10
|
||||
end_time = context.full_get_segment_t1(i_segment) * 10
|
||||
text = context.full_get_segment_text(i_segment)
|
||||
|
||||
assert_kind_of Integer, start_time
|
||||
assert start_time >= 0
|
||||
assert_kind_of Integer, end_time
|
||||
assert end_time > 0
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0
|
||||
end
|
||||
}
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
end
|
||||
|
||||
def test_on_new_segment
|
||||
seg = nil
|
||||
index = 0
|
||||
@params.on_new_segment do |segment|
|
||||
assert_instance_of Parakeet::Segment, segment
|
||||
if index == 0
|
||||
seg = segment
|
||||
assert_equal 0, segment.start_time
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text)
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
assert_equal 0, seg.start_time
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text
|
||||
end
|
||||
|
||||
def test_on_new_token
|
||||
index = 0
|
||||
@params.on_new_token do |token|
|
||||
assert_instance_of Parakeet::Token, token
|
||||
if index == 0
|
||||
assert_instance_of Integer, token.start_time
|
||||
assert_match "▁And", token.text
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
end
|
||||
|
||||
def test_on_progress
|
||||
first = nil
|
||||
@params.on_progress do |progress|
|
||||
assert_kind_of Integer, progress
|
||||
assert 0 <= progress && progress <= 100
|
||||
first = progress if first.nil?
|
||||
end
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
|
||||
assert_equal 0, first
|
||||
end
|
||||
|
||||
def test_on_encoder_begin
|
||||
i = 0
|
||||
@params.on_encoder_begin do
|
||||
i += 1
|
||||
end
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
|
||||
assert i > 0
|
||||
end
|
||||
|
||||
def test_abort_on
|
||||
do_abort = false
|
||||
@params.on_new_segment do |segment|
|
||||
do_abort = true if segment.text.match?(/ask/)
|
||||
end
|
||||
i = 0
|
||||
@params.abort_on do
|
||||
i += 1
|
||||
do_abort
|
||||
end
|
||||
|
||||
@parakeet.transcribe(AUDIO, @params) rescue nil
|
||||
|
||||
assert i > 0
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
require_relative "helper"
|
||||
require "stringio"
|
||||
|
||||
class TestParakeetContext < TestBase
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
Whisper.instance_variable_set "@whisper", nil
|
||||
GC.start
|
||||
|
||||
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
@params = Parakeet::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Context, @parakeet
|
||||
end
|
||||
|
||||
def test_new_with_params
|
||||
log_callback = Parakeet.instance_variable_get(:@log_callback)
|
||||
user_data = Parakeet.instance_variable_get(:@log_callback_user_data)
|
||||
begin
|
||||
logs = ""
|
||||
Parakeet.log_set proc {|level, message| logs << message}, nil
|
||||
params = Parakeet::Context::Params.new(use_gpu: false)
|
||||
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params)
|
||||
assert_instance_of Parakeet::Context, parakeet
|
||||
assert_match /use gpu\s+=\s+0/, logs
|
||||
ensure
|
||||
Parakeet.log_set log_callback, user_data
|
||||
end
|
||||
end
|
||||
|
||||
sub_test_case "full" do
|
||||
def setup
|
||||
super
|
||||
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
|
||||
end
|
||||
|
||||
def test_full
|
||||
@parakeet.full @params, @samples, @samples.length
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text
|
||||
end
|
||||
|
||||
def test_full_without_length
|
||||
@parakeet.full(@params, @samples)
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
|
||||
end
|
||||
|
||||
def test_full_enumerator
|
||||
samples = @samples.each
|
||||
@parakeet.full @params, samples, @samples.length
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
|
||||
end
|
||||
|
||||
def test_full_enumerator_without_length
|
||||
samples = @samples.each
|
||||
assert_raise ArgumentError do
|
||||
@parakeet.full @params, samples
|
||||
end
|
||||
end
|
||||
|
||||
def test_full_enumerator_with_too_large_length
|
||||
samples = @samples.each.take(10).to_enum
|
||||
assert_raise StopIteration do
|
||||
@parakeet.full @params, samples, 11
|
||||
end
|
||||
end
|
||||
|
||||
def test_full_with_memory_view
|
||||
samples = JFKReader.new(AUDIO)
|
||||
@parakeet.full @params, samples
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
|
||||
end
|
||||
|
||||
def test_full_with_memroy_view_gc
|
||||
samples = JFKReader.new(AUDIO)
|
||||
@parakeet.full(@params, samples)
|
||||
GC.start
|
||||
require "fiddle"
|
||||
Fiddle::MemoryView.export samples do |view|
|
||||
assert_equal 176000, view.to_s.unpack("#{view.format}*").length
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def test_transcribe
|
||||
assert_nothing_raised do
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
end
|
||||
end
|
||||
|
||||
def test_transcribe_with_pathname
|
||||
assert_nothing_raised do
|
||||
@parakeet.transcribe Pathname(AUDIO), @params
|
||||
end
|
||||
end
|
||||
|
||||
def test_transcribe_with_nothing
|
||||
assert_raise_message(/open/) do
|
||||
@parakeet.transcribe "nothing", @params
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetContextParams < TestBase
|
||||
def setup
|
||||
@params = Parakeet::Context::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Context::Params, @params
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
assert_true @params.use_gpu
|
||||
assert_instance_of Integer, @params.gpu_device
|
||||
end
|
||||
|
||||
def test_attribute_writer
|
||||
@params.use_gpu = false
|
||||
assert_false @params.use_gpu
|
||||
|
||||
@params.gpu_device = 2
|
||||
assert_equal 2, @params.gpu_device
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetModel < TestBase
|
||||
def test_model
|
||||
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
|
||||
assert_instance_of Parakeet::Model, parakeet.model
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
|
||||
model = parakeet.model
|
||||
|
||||
assert_equal 10, model.n_vocab
|
||||
assert_equal 3200, model.n_audio_ctx
|
||||
assert_equal 8, model.n_audio_state
|
||||
assert_equal 2, model.n_audio_head
|
||||
assert_equal 1, model.n_audio_layer
|
||||
assert_equal 16, model.n_mels
|
||||
assert_equal 0, model.ftype
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
require_relative "helper"
|
||||
require "etc"
|
||||
|
||||
class TestParakeetParams < TestBase
|
||||
PARAM_NAMES = [
|
||||
:n_threads,
|
||||
:offset_ms,
|
||||
:duration_ms,
|
||||
:no_context,
|
||||
:audio_ctx
|
||||
]
|
||||
|
||||
def setup
|
||||
@params = Parakeet::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Params, @params
|
||||
end
|
||||
|
||||
def test_n_threads
|
||||
assert_equal [4, Etc.nprocessors].min, @params.n_threads
|
||||
|
||||
@params.n_threads = 1
|
||||
assert_equal 1, @params.n_threads
|
||||
end
|
||||
|
||||
def test_offset_ms
|
||||
assert_equal 0, @params.offset_ms
|
||||
|
||||
@params.offset_ms = 10_000
|
||||
assert_equal 10_000, @params.offset_ms
|
||||
end
|
||||
|
||||
def test_duration_ms
|
||||
assert_equal 0, @params.duration_ms
|
||||
|
||||
@params.duration_ms = 60_000
|
||||
assert_equal 60_000, @params.duration_ms
|
||||
end
|
||||
|
||||
def test_no_context
|
||||
assert_equal true, @params.no_context
|
||||
|
||||
@params.no_context = false
|
||||
assert_equal false, @params.no_context
|
||||
end
|
||||
|
||||
def test_audio_ctx
|
||||
assert_equal 0, @params.audio_ctx
|
||||
|
||||
@params.audio_ctx = 1
|
||||
assert_equal 1, @params.audio_ctx
|
||||
end
|
||||
|
||||
def test_new_with_kw_args
|
||||
params = Parakeet::Params.new(n_threads: 1)
|
||||
assert_equal 1, params.n_threads
|
||||
assert_equal 0, params.offset_ms
|
||||
end
|
||||
|
||||
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
|
||||
def test_new_with_kw_args_default_values(param)
|
||||
default_value = @params.send(param)
|
||||
value = case [param, default_value]
|
||||
in [*, true | false]
|
||||
!default_value
|
||||
in [*, Integer]
|
||||
default_value + 1
|
||||
end
|
||||
params = Parakeet::Params.new(param => value)
|
||||
assert_equal value, params.send(param)
|
||||
|
||||
PARAM_NAMES.reject {|name| name == param}.each do |name|
|
||||
assert_equal @params.send(name), params.send(name)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetSegment < TestBase
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
@parakeet.transcribe AUDIO, Parakeet::Params.new
|
||||
end
|
||||
|
||||
def test_segment
|
||||
whole_text = ""
|
||||
@parakeet.each_segment do |segment|
|
||||
assert_instance_of Parakeet::Segment, segment
|
||||
assert_kind_of Integer, segment.start_time
|
||||
assert segment.end_time >= segment.start_time
|
||||
assert_kind_of String, segment.text
|
||||
whole_text << segment.text
|
||||
end
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text)
|
||||
end
|
||||
|
||||
def test_deconstruct_keys
|
||||
segment = @parakeet.each_segment.first
|
||||
expected = {
|
||||
start_time: segment.start_time,
|
||||
end_time: segment.end_time,
|
||||
text: segment.text
|
||||
}
|
||||
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text])
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_nil
|
||||
segment = @parakeet.each_segment.first
|
||||
expected = {
|
||||
start_time: segment.start_time,
|
||||
end_time: segment.end_time,
|
||||
text: segment.text
|
||||
}
|
||||
assert_equal expected, segment.deconstruct_keys(nil)
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetToken < TestBase
|
||||
ATTRS = %i[
|
||||
id
|
||||
duration_idx
|
||||
duration_value
|
||||
frame_index
|
||||
probability
|
||||
log_probability
|
||||
start_time
|
||||
end_time
|
||||
word_start?
|
||||
text
|
||||
]
|
||||
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
Whisper.instance_variable_set "@whisper", nil
|
||||
GC.start
|
||||
|
||||
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
params = Parakeet::Params.new
|
||||
parakeet.transcribe AUDIO, params
|
||||
@segment = parakeet.each_segment.first
|
||||
end
|
||||
|
||||
def test_each_token
|
||||
i = 0
|
||||
@segment.each_token do |token|
|
||||
i += 1
|
||||
assert_instance_of Parakeet::Token, token
|
||||
end
|
||||
assert_equal 38, i
|
||||
end
|
||||
|
||||
def test_each_token_without_block
|
||||
assert_instance_of Enumerator, @segment.each_token
|
||||
end
|
||||
|
||||
def test_token
|
||||
token = @segment.each_token.first
|
||||
|
||||
assert_instance_of Parakeet::Token, token
|
||||
assert_instance_of Integer, token.id
|
||||
assert_instance_of Integer, token.duration_idx
|
||||
assert_instance_of Integer, token.duration_value
|
||||
assert_instance_of Integer, token.frame_index
|
||||
assert_instance_of Float, token.probability
|
||||
assert_instance_of Float, token.log_probability
|
||||
assert_instance_of Integer, token.start_time
|
||||
assert_instance_of Integer, token.end_time
|
||||
assert_instance_of String, token.text
|
||||
end
|
||||
|
||||
def test_text
|
||||
assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."],
|
||||
@segment.each_token.collect(&:text)
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_nil
|
||||
token = @segment.each_token.first
|
||||
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
|
||||
assert_equal expected, token.deconstruct_keys(nil)
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_keys
|
||||
token = @segment.each_token.first
|
||||
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
|
||||
assert_equal expected, token.deconstruct_keys(expected.keys)
|
||||
end
|
||||
end
|
||||
|
|
@ -9,7 +9,7 @@ class TestVADSegment < TestBase
|
|||
end
|
||||
|
||||
assert_raise do
|
||||
segments.end_time
|
||||
segment.end_time
|
||||
end
|
||||
|
||||
assert_raise do
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ class TestWhisper < TestBase
|
|||
}
|
||||
Whisper.log_set log_callback, user_data
|
||||
Whisper::Context.new("base.en")
|
||||
sleep 0.1 # wait for logs dequeued
|
||||
|
||||
assert logs.length > 30
|
||||
logs.each do |log|
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ Gem::Specification.new do |s|
|
|||
s.test_files = s.files.select {|file| file.start_with? "test/"}
|
||||
|
||||
s.extensions << 'ext/extconf.rb'
|
||||
s.required_ruby_version = '>= 3.1.0'
|
||||
s.required_ruby_version = '>= 3.3.0'
|
||||
|
||||
#### Documentation and testing.
|
||||
s.homepage = 'https://github.com/ggml-org/whisper.cpp'
|
||||
|
|
|
|||
Loading…
Reference in New Issue