ruby : add support for Parakeet (#3885)

* Add Whisper::Parakeet::Params

* Add tests for Parakeet::Params

* Remove unused variable

* Add callbacks to Parakeet::Params

* Group callback and user_data params

* Undefine local macros

* Define GetParakeetParams

* Remove unused variable

* Use ITERATE_CALLBACK_PARAMS

* Use ITERATE_CALLBACK_PARAMS instead of ITERATE_USER_DATA_PARAMS

* Fix memsize

* Remove unnecessary macros

* Simplify params registration

* Define Parakeet

* Add hook methods to Parakeet::Params

* Fix typo

* Check callback container in GetParakeetParams

* Reduce if

* Free parakeet_full_params

* Implement Parakeet::Context#initialize

* Add TestParakeetContext

* Add Parakeet::Segment

* Prevent double-free

* Add Parakeet::Context#transcribe

* Add Parakeet::Context#each_segment

* Define Parakeet::Segment attributes

* Define Parakeet::Segment#deconstruct_keys

* Add tests for Parakeet::Segment#deconstruct_keys

* Run Parakeet::Context#transcribe without GVL

* Make it possible to abort for Parakeet

* Add Parakeet.log_set

* Define Parakeet::Token

* Define Parakeet::Segment#each_token

* Implement some hooks of Parakeet::Params

* Convert int to VALUE

* Implement hooks for Parakeet

* Implement Parakeet::Context#full

* Add tests for Parakeet::Context#full

* Add Parakeet to RBS

* Fix ruby_whisper_parakeet_params_free

* Free ruby_whisper_parakeet_context

* Add tests for hooks

* Add Parakeet section to README

* Add more attributes of Parakeet::Context

* Add tests for Parakeet::Context's attributes

* Update RBS

* Register parakeet-tdt-0.6b-v3

* Narrow scope of log constants

* Extract activate and deactivate of log_queue

* Make start_log_callback_thread private

* Don't call start_log_callback_thread unnecessarily

* Early return from log_queue_enqueue when not active

* Group log_queue members

* is_active -> is_open

* Fix English

* Share parakeet full body function

* ruby_whisper_parakeet_abort_callback_user_data -> ruby_whisper_abort_callback_user_data

* NULL check for callback containers

* Fix Parakeet.log_set

* Omit Parakeet tests on CI

* Extract Whisper::LogSettable

* Join log callback thread in a log queue function

* Revert "Join log callback thread in a log queue function"

* Extract output methods to modules

* Move Parakeet init functions into init_parakeet()

* Add output methods to Parakeet classes

* Add Parakeet's output methods to RBS

* Use Whisper::Output in RBS

* Add LogSettable to RBS

* Fix module Token -> class Token

* Add Parakeet::Model

* Add test for Parakeet::Model

* Add Parakeet::Model to RBS

* Move position of Parakeet::Model in RBS

* Parakeet -> TestBase::Parakeet

* Add Parakeet::Context#model in RBS

* Add Whisper::Output

* Fix nil check

* Define ruby_whisper_parakeet_model_memsize

* Fix order of declaration in ruby_whisper_parakeet_model_get_xxx

* Define Parakeet.system_info_str

* Add test for Parakeet.system_info_str

* Add signature of Parakeet.system_info_str

* Define Parakeet::VERSION

* Add test for Parakeet::VERSION

* Add signature of Parakeet::VERSION

* Add Parakeet::Context::Params

* Make Parakeet::Context.new accept Context::Params

* Add test for Parakeet::Context.new with Context::Params

* Update RBS

* Remove params from Parakeet::Params which are moved from whisper_parakeet_full_params

* Remove tests for removed params

* Make Parakeet tests follow original behavior changes

* Add Parakeet model shortcuts

* Alloc token data in factory instead of alloc func

* Fix variable name

* Update RBS

* Refactor log settable module

* Use log settable for Whisper

* Address deadlock

* Make test follow change of log queue implementation

* Refactor abort callback to work the same way as Parakeet's

* Remove redundant structs

* Fix test name

* Fix README

* Add missing parallel transcription

* Fix test for parakeet info

* Remove removed params

* Wait for logs dequeued

* Fix instance variable name

* Load etc feature

* Remove unnecessary comment

* Remove unnecessary thread safety check

* Remove outdated comment

* Skip downloading model if cache exists

* Change Hugging Face URI for Parakeet models

* Bump required Ruby version to 3.3

* Fix English
KITAITI Makoto 2026-06-17 13:42:09 +09:00 committed by GitHub
parent 9efddafb91
commit 0d14756929
38 changed files with 3004 additions and 332 deletions

View File

@ -27,6 +27,6 @@ jobs:
steps:
- uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0
with:
ruby-version: '3.2'
ruby-version: '3.3'
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6
- run: rake test

View File

@ -396,6 +396,37 @@ whisper
.full(Whisper::Params.new, samples)
```
### Parakeet ###
The whispercpp gem now supports Parakeet, NVIDIA's ASR model.
If you want to use Parakeet instead of Whisper, the API should feel familiar.
In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way.
```ruby
require "whisper"
# Assigning Whisper::Parakeet to a top-level Parakeet constant is convenient, unless you also use the Parakeet gem.
Parakeet = Whisper::Parakeet
parakeet = Parakeet::Context.new("path/to/model")
params = Parakeet::Params.new(
no_context: true
)
parakeet
.transcribe("path/to/audio.wav", params)
.each_segment do |segment|
puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}"
end
```
The main differences are:
* Namespace is `Whisper::Parakeet`.
* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks.
Custom context params
---------------------

View File

@ -84,6 +84,21 @@ else
end
end
TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin"
TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin"))
TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d")
directory TEST_PARAKEET_MODEL_DIR
if File.exist? TEST_PARAKEET_MODEL_SRC
file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t|
symlink t.source, t.name
end
else
require "open-uri"
file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t|
File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read
end
end
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
chdir "test/jfk_reader" do
@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
end
CLEAN.include TEST_MEMORY_VIEW
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL]

View File

@ -1,19 +1,29 @@
#include "ruby_whisper.h"
VALUE mWhisper;
VALUE mLogSettable;
VALUE mVAD;
VALUE mParakeet;
VALUE cContext;
VALUE cParams;
VALUE cVADContext;
VALUE cVADParams;
VALUE cVADSegments;
VALUE cVADSegment;
VALUE cParakeetContext;
VALUE cParakeetContextParams;
VALUE cParakeetParams;
VALUE cParakeetSegment;
VALUE cParakeetModel;
VALUE eError;
VALUE cSegment;
VALUE cToken;
VALUE cModel;
VALUE mOutputContext;
VALUE mOutputSegment;
ID id_to_s;
ID id_call;
ID id___method__;
@ -27,9 +37,11 @@ ID id_pre_converted_models;
ID id_coreml_compiled_models;
ID id_cache;
ID id_n_processors;
static bool is_log_callback_finalized = false;
static bool is_ruby_log_callback_present = false;
ID id_extended;
ID id_start_log_callback_thread;
ID id_log_callback_thread;
ID id_alive_p;
ID id_join;
// High level API
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD);
extern void init_ruby_whisper_vad_context(VALUE *mVAD);
extern void init_ruby_whisper_vad_segment(VALUE *mVAD);
extern void init_ruby_whisper_vad_segments(VALUE *mVAD);
extern void init_ruby_whisper_parakeet(VALUE *mWhisper);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
static ruby_whisper_log_queue whisper_log_queue;
LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set)
/*
* call-seq:
* lang_max_id -> Integer
@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) {
return rb_str_new2(whisper_print_system_info());
}
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
is_log_callback_finalized = true;
return Qnil;
}
typedef struct {
int level;
const char * buffer;
} call_log_callbacks_args;
static void*
call_log_callbacks(void *v_args) {
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
if (NIL_P(log_callback)) {
return NULL;
}
call_log_callbacks_args *args = (call_log_callbacks_args *)v_args;
VALUE user_data = rb_iv_get(mWhisper, "user_data");
rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data);
return NULL;
}
static void
ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
if (is_log_callback_finalized) {
return;
}
if (!is_ruby_log_callback_present) {
return;
}
call_log_callbacks_args args = {
level,
buffer,
};
if (ruby_thread_has_gvl_p()) {
call_log_callbacks((void *)&args);
} else {
rb_thread_call_with_gvl(call_log_callbacks, (void *)&args);
}
}
/*
* call-seq:
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
*/
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
VALUE old_callback = rb_iv_get(self, "log_callback");
if (!NIL_P(old_callback)) {
rb_undefine_finalizer(old_callback);
}
rb_iv_set(self, "log_callback", log_callback);
rb_iv_set(self, "user_data", user_data);
if (!NIL_P(log_callback)) {
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
rb_define_finalizer(log_callback, finalize_log_callback);
}
if (NIL_P(log_callback)) {
whisper_log_set(NULL, NULL);
is_ruby_log_callback_present = false;
} else {
whisper_log_set(ruby_whisper_log_callback, NULL);
is_ruby_log_callback_present = true;
}
return Qnil;
}
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
@ -189,9 +133,19 @@ void Init_whisper() {
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
id_cache = rb_intern("cache");
id_n_processors = rb_intern("n_processors");
id_extended = rb_intern("extended");
id_start_log_callback_thread = rb_intern("start_log_callback_thread");
id_log_callback_thread = rb_intern("@log_callback_thread");
id_alive_p = rb_intern("alive?");
id_join = rb_intern("join");
mWhisper = rb_define_module("Whisper");
rb_require("whisper/log_settable");
mLogSettable = rb_path2class("Whisper::LogSettable");
mVAD = rb_define_module_under(mWhisper, "VAD");
rb_require("whisper/output");
mOutputContext = rb_path2class("Whisper::Output::Context");
mOutputSegment = rb_path2class("Whisper::Output::Segment");
rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
@ -222,8 +176,8 @@ void Init_whisper() {
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
LOG_SETTABLE_INIT(whisper_log_queue, mWhisper)
cContext = init_ruby_whisper_context(&mWhisper);
init_ruby_whisper_context_params(&cContext);
@ -236,8 +190,10 @@ void Init_whisper() {
init_ruby_whisper_vad_segment(&mVAD);
init_ruby_whisper_vad_segments(&mVAD);
init_ruby_whisper_vad_context(&mVAD);
init_ruby_whisper_parakeet(&mWhisper);
rb_require("whisper/context");
rb_require("whisper/segment");
rb_require("whisper/model/uri");
rb_include_module(cContext, mOutputContext);
rb_include_module(cSegment, mOutputSegment);
}

View File

@ -5,8 +5,12 @@
#include <ruby/version.h>
#include <ruby/util.h>
#include <ruby/thread.h>
#include <ruby/thread_native.h>
#include <ruby/atomic.h>
#include <ruby/memory_view.h>
#include "whisper.h"
#include "parakeet.h"
#include "ruby_whisper_log_settable.h"
#if RUBY_API_VERSION_MAJOR < 4
// Exists but not declared as public API
@ -20,13 +24,28 @@ typedef struct {
VALUE callbacks;
} ruby_whisper_callback_container;
typedef struct {
VALUE *context;
VALUE user_data;
VALUE callback;
VALUE callbacks;
bool is_interrupted;
} ruby_whisper_abort_callback_container;
typedef struct ruby_whisper_abort_callback_user_data {
volatile rb_atomic_t is_interrupted;
ruby_whisper_callback_container *callback_container;
} ruby_whisper_abort_callback_user_data;
typedef struct ruby_whisper_log {
enum ggml_log_level level;
char *text;
size_t length;
size_t capacity;
} ruby_whisper_log;
typedef struct ruby_whisper_log_queue {
rb_nativethread_lock_t lock;
rb_nativethread_cond_t cond;
bool is_open;
size_t head;
size_t tail;
size_t size;
ruby_whisper_log *logs;
} ruby_whisper_log_queue;
typedef struct {
struct whisper_context *context;
@ -42,7 +61,7 @@ typedef struct {
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_abort_callback_container *abort_callback_container;
ruby_whisper_callback_container *abort_callback_container;
VALUE vad_params;
} ruby_whisper_params;
@ -84,6 +103,63 @@ typedef struct parsed_samples_t {
bool memview_exported;
} parsed_samples_t;
typedef struct {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
} ruby_whisper_full_args;
typedef struct ruby_whisper_full_parallel_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
int n_processors;
} ruby_whisper_full_parallel_args;
typedef struct {
struct parakeet_full_params params;
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *new_token_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_parakeet_params;
typedef struct {
struct parakeet_context_params params;
} ruby_whisper_parakeet_context_params;
typedef struct {
struct parakeet_context *context;
} ruby_whisper_parakeet_context;
typedef struct {
VALUE context;
int index;
} ruby_whisper_parakeet_segment;
typedef struct {
parakeet_token_data *token_data;
VALUE text;
} ruby_whisper_parakeet_token;
typedef struct {
VALUE context;
} ruby_whisper_parakeet_model;
extern ID id_extended;
extern ID id_log_callback_thread;
extern ID id_start_log_callback_thread;
extern ID id_alive_p;
extern ID id_join;
extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue);
extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue);
extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue);
extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text);
extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue);
#define GetContext(obj, rw) do { \
TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \
if ((rw)->context == NULL) { \
@ -120,4 +196,47 @@ typedef struct parsed_samples_t {
} \
} while (0)
#define GetParakeetContextParams(obj, rwpcp) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \
} while (0)
#define GetParakeetContext(obj, rwpc) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \
if ((rwpc)->context == NULL) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetParams(obj, rwpp) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \
if (!(rwpp)->new_segment_callback_container || \
!(rwpp)->new_token_callback_container || \
!(rwpp)->progress_callback_container || \
!(rwpp)->encoder_begin_callback_container || \
!(rwpp)->abort_callback_container) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetSegment(obj, rwps) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \
if (!(rwps)->context) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetToken(obj, rwpt) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \
if (!(rwpt)->token_data) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetModel(obj, rwpm) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \
if (NIL_P((rwpm)->context)) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#endif

View File

@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
ID transcribe_option_names[1];
@ -38,21 +38,6 @@ typedef struct fill_samples_args {
int n_samples;
} fill_samples_args;
typedef struct full_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
} full_args;
typedef struct full_parallel_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
int n_processors;
} full_parallel_args;
typedef struct full_without_gvl_args {
struct whisper_context *context;
struct whisper_full_params *params;
@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args {
} full_parallel_without_gvl_args;
typedef struct full_ubf_args {
ruby_whisper_abort_callback_container *abort_callback_container;
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
} full_ubf_args;
static void
@ -379,7 +364,7 @@ fill_samples(VALUE rb_args)
return Qnil;
}
struct parsed_samples_t
parsed_samples_t
parse_samples(VALUE *samples, VALUE *n_samples)
{
bool memview_available = rb_memory_view_available_p(*samples);
@ -480,20 +465,24 @@ full_ubf(void *rb_args)
{
full_ubf_args *args = (full_ubf_args *)rb_args;
args->abort_callback_container->is_interrupted = true;
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
}
static VALUE
VALUE
full_body(VALUE rb_args)
{
full_args *args = (full_args *)rb_args;
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
ruby_whisper *rw;
ruby_whisper_params *rwp;
GetContext(*args->context, rw);
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
prepare_transcription(rwp, args->context, 1);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
prepare_transcription(rwp, args->context, 1, &abort_callback_user_data);
struct full_without_gvl_args full_without_gvl_args = {
rw->context,
@ -503,7 +492,7 @@ full_body(VALUE rb_args)
0,
};
full_ubf_args full_ubf_args = {
rwp->abort_callback_container,
&abort_callback_user_data,
};
rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_without_gvl_args.result);
@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
VALUE n_samples = argc == 2 ? Qnil : argv[2];
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
full_args args = {
ruby_whisper_full_args args = {
&self,
&argv[0],
parsed.samples,
@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args)
return NULL;
}
static VALUE
VALUE
full_parallel_body(VALUE rb_args)
{
full_parallel_args *args = (full_parallel_args *)rb_args;
ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args;
ruby_whisper *rw;
ruby_whisper_params *rwp;
GetContext(*args->context, rw);
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
prepare_transcription(rwp, args->context, args->n_processors);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data);
struct full_parallel_without_gvl_args full_parallel_without_gvl_args = {
rw->context,
@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args)
0,
};
full_ubf_args full_ubf_args = {
rwp->abort_callback_container,
&abort_callback_user_data,
};
rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_parallel_without_gvl_args.result);
@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
break;
}
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
const full_parallel_args args = {
const ruby_whisper_full_parallel_args args = {
&self,
&argv[0],
parsed.samples,
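
The hunks above replace the plain `bool is_interrupted` on the callback container with an atomic flag in `ruby_whisper_abort_callback_user_data`, which the unblock function `full_ubf` sets through `RUBY_ATOMIC_SET`. The pattern can be sketched in isolation as follows. This is a minimal illustration, not the extension's code: it uses C11 `<stdatomic.h>` as a stand-in for Ruby's atomic macros, every `sketch_*` name is hypothetical, and the cancellation request is simulated on the same thread so the example stays deterministic.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical stand-in for ruby_whisper_abort_callback_user_data. */
typedef struct {
    atomic_int is_interrupted;
} sketch_abort_state;

/* Plays the role of the unblock function: only flips the flag. */
static void sketch_request_abort(void *data)
{
    sketch_abort_state *state = (sketch_abort_state *)data;
    assert(state != NULL);
    atomic_store(&state->is_interrupted, 1);
}

/* Plays the role of an abort callback polled by the worker. */
static bool sketch_should_abort(void *data)
{
    sketch_abort_state *state = (sketch_abort_state *)data;
    assert(state != NULL);
    return atomic_load(&state->is_interrupted) != 0;
}

/*
 * Worker loop: returns how many steps completed before an abort was seen.
 * abort_at simulates the moment another thread would request the abort;
 * pass a negative value to never abort.
 */
static int sketch_run(sketch_abort_state *state, int max_steps, int abort_at)
{
    int done = 0;
    for (int step = 0; step < max_steps; step++) {
        if (step == abort_at) {
            sketch_request_abort(state);
        }
        if (sketch_should_abort(state)) {
            break;
        }
        done++;
    }
    return done;
}
```

Keeping the flag in a stack-allocated struct that lives for one call, as `full_body` and `full_parallel_body` do with `abort_callback_user_data`, means each transcription starts with a fresh, non-interrupted state.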

View File

@ -0,0 +1,180 @@
#include "ruby_whisper.h"
#define LOG_QUEUE_CAPACITY 256
#define LOG_DEFAULT_CAPACITY 1024
void
ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue)
{
rb_nativethread_lock_initialize(&log_queue->lock);
rb_native_cond_initialize(&log_queue->cond);
log_queue->head = 0;
log_queue->tail = 0;
log_queue->size = 0;
log_queue->is_open = true;
log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY);
for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) {
// we cannot use Ruby APIs such as ALLOC_N here because this slot may be realloc'ed without the GVL
// this is never freed because the log queue lives until the end of the process
char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY);
if (!slot) {
rb_raise(rb_eRuntimeError, "Could not allocate memory for log text");
}
ruby_whisper_log log = {
0,
slot,
0,
LOG_QUEUE_CAPACITY,
};
log_queue->logs[i] = log;
}
}
void
ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue)
{
rb_nativethread_lock_lock(&log_queue->lock);
log_queue->is_open = true;
rb_native_cond_signal(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
}
void
ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue)
{
rb_nativethread_lock_lock(&log_queue->lock);
log_queue->is_open = false;
rb_native_cond_broadcast(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
}
static size_t
calc_enough_cap(size_t len)
{
size_t quot = len / LOG_DEFAULT_CAPACITY;
size_t rem = len % LOG_DEFAULT_CAPACITY;
return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY;
}
void
ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text)
{
rb_nativethread_lock_lock(&log_queue->lock);
if (!log_queue->is_open) {
rb_nativethread_lock_unlock(&log_queue->lock);
return;
}
size_t len = strlen(text);
ruby_whisper_log *log = &log_queue->logs[log_queue->head];
if (len > log->capacity) {
size_t new_cap = calc_enough_cap(len);
// we cannot call Ruby API like REALLOC_N because this function is called without GVL
char *slot = realloc(log->text, new_cap);
if (!slot) {
rb_nativethread_lock_unlock(&log_queue->lock);
return;
}
log->text = slot;
log->capacity = new_cap;
}
// we cannot call Ruby API like MEMCPY because this function is called without GVL
memcpy(log->text, text, sizeof(char) * len);
log->length = len;
log->level = level;
log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY;
bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY;
log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1;
if (is_full) {
log_queue->tail = log_queue->head;
}
rb_native_cond_signal(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
}
static void*
ruby_whisper_log_queue_wait(void *args)
{
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
rb_native_cond_wait(&log_queue->cond, &log_queue->lock);
rb_nativethread_lock_unlock(&log_queue->lock);
return NULL;
}
static void
ruby_whisper_log_queue_wait_ubf(void *args)
{
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
rb_native_cond_broadcast(&log_queue->cond);
}
typedef struct {
enum ggml_log_level level;
size_t length;
char *text;
} log_snapshot;
VALUE
ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue)
{
log_snapshot logs[LOG_QUEUE_CAPACITY];
rb_nativethread_lock_lock(&log_queue->lock);
while (log_queue->size == 0 && log_queue->is_open) {
rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue);
rb_nativethread_lock_lock(&log_queue->lock);
}
if (log_queue->size == 0 && !log_queue->is_open) {
rb_native_cond_broadcast(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
return Qnil;
}
size_t size = log_queue->size;
ruby_whisper_log *log;
size_t i;
for (i = 0; i < size; i++) {
log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY];
logs[i].level = log->level;
logs[i].length = log->length;
char *text = malloc(log->length);
if (!text) {
logs[i].text = NULL;
continue;
}
logs[i].text = text;
memcpy(logs[i].text, log->text, log->length);
}
log_queue->size = 0;
log_queue->tail = log_queue->head;
rb_native_cond_signal(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
VALUE rb_logs = rb_ary_new2(size);
VALUE rb_text;
for (i = 0; i < size; i++) {
if (!logs[i].text) {
continue;
}
rb_text = rb_str_new(logs[i].text, logs[i].length);
free(logs[i].text);
rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text));
}
return rb_logs;
}
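
The queue in this file is a fixed-size ring buffer: `ruby_whisper_log_queue_enqueue` always writes at `head`, and once the buffer is full it moves `tail` along with `head`, so the oldest unread entry is dropped. `ruby_whisper_log_queue_drain` then copies entries oldest-first and resets the queue. The sketch below isolates that index arithmetic only. It is an illustration under simplifying assumptions: no locking, no condition variable, fixed-size text slots, and hypothetical `sketch_*` names with a capacity of 4 instead of `LOG_QUEUE_CAPACITY`.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

#define SKETCH_CAPACITY 4   /* hypothetical; the diff uses 256 */
#define SKETCH_TEXT_MAX 32  /* hypothetical fixed slot size */

typedef struct {
    char   text[SKETCH_CAPACITY][SKETCH_TEXT_MAX];
    size_t head;  /* next slot to write */
    size_t tail;  /* oldest unread slot */
    size_t size;  /* number of unread entries */
} sketch_queue;

static void sketch_init(sketch_queue *q)
{
    assert(q != NULL);
    memset(q, 0, sizeof(*q));
}

/* Write at head; when already full, the oldest entry is overwritten. */
static void sketch_enqueue(sketch_queue *q, const char *text)
{
    assert(q != NULL && text != NULL);
    bool was_full = q->size >= SKETCH_CAPACITY;
    strncpy(q->text[q->head], text, SKETCH_TEXT_MAX - 1);
    q->text[q->head][SKETCH_TEXT_MAX - 1] = '\0';
    q->head = (q->head + 1) % SKETCH_CAPACITY;
    if (was_full) {
        q->tail = q->head;  /* tail follows head, dropping the oldest */
    } else {
        q->size++;
    }
}

/* Copy unread entries oldest-first into out, then reset the queue. */
static size_t sketch_drain(sketch_queue *q, char out[][SKETCH_TEXT_MAX])
{
    assert(q != NULL && out != NULL);
    size_t n = q->size;
    for (size_t i = 0; i < n; i++) {
        memcpy(out[i], q->text[(q->tail + i) % SKETCH_CAPACITY], SKETCH_TEXT_MAX);
    }
    q->size = 0;
    q->tail = q->head;
    return n;
}
```

Enqueuing six messages `a` through `f` into this four-slot queue and draining once yields `c`, `d`, `e`, `f`: the producer never blocks, at the cost of losing the oldest entries when the consumer falls behind.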

View File

@ -0,0 +1,47 @@
#ifndef RUBY_WHISPER_LOG_SETTABLE_H
#define RUBY_WHISPER_LOG_SETTABLE_H
#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \
static VALUE \
ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \
{ \
return ruby_whisper_log_queue_drain(&log_queue); \
} \
static void \
ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \
{ \
ruby_whisper_log_queue_enqueue(&log_queue, level, text); \
} \
static VALUE \
ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \
{ \
rb_iv_set(self, "@log_callback", log_callback); \
rb_iv_set(self, "@log_callback_user_data", user_data); \
if (NIL_P(log_callback)) { \
log_set(NULL, NULL); \
} else { \
ruby_whisper_log_queue_open(&log_queue); \
rb_funcall((mod), id_start_log_callback_thread, 0); \
log_set(ruby_whisper_##log_queue##_log_callback, NULL); \
} \
return Qnil; \
} \
static void \
ruby_whisper_##log_queue##_end_proc(VALUE args) \
{ \
ruby_whisper_log_queue_close(&log_queue); \
VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \
if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \
rb_funcall(log_callback_thread, id_join, 0); \
} \
}
#define LOG_SETTABLE_INIT(log_queue, mod) \
ruby_whisper_log_queue_initialize(&log_queue); \
rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \
rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \
rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \
rb_extend_object(mod, mLogSettable); \
rb_funcall(mLogSettable, id_extended, 1, mod);
#endif
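
`LOG_SETTABLE_SETUP` relies on token pasting (`##`) to stamp out one family of static functions per queue variable, which is how `whisper_log_queue` and `parakeet_log_queue` each get their own `drain_logs`, `log_set` and end-proc functions from a single definition. A reduced sketch of that technique, with a hypothetical counter in place of the log queue and hypothetical `sketch_*` names throughout:

```c
#include <assert.h>
#include <stddef.h>

typedef struct {
    size_t count;
} sketch_counter;

/*
 * Hypothetical analogue of LOG_SETTABLE_SETUP: the argument is both the
 * object the generated functions operate on and part of their names.
 */
#define SKETCH_COUNTER_SETUP(counter) \
    static void \
    sketch_##counter##_bump(void) \
    { \
        (counter).count++; \
    } \
    static size_t \
    sketch_##counter##_read(void) \
    { \
        return (counter).count; \
    }

static sketch_counter whisper_counter;
static sketch_counter parakeet_counter;

/* Expands to sketch_whisper_counter_bump() / sketch_whisper_counter_read(). */
SKETCH_COUNTER_SETUP(whisper_counter)
/* Expands to sketch_parakeet_counter_bump() / sketch_parakeet_counter_read(). */
SKETCH_COUNTER_SETUP(parakeet_counter)

/* The two generated families share no state. */
static void sketch_counter_demo(void)
{
    sketch_whisper_counter_bump();
    sketch_whisper_counter_bump();
    sketch_parakeet_counter_bump();
    assert(sketch_whisper_counter_read() == 2);
    assert(sketch_parakeet_counter_read() == 1);
}
```

The trade-off is the usual one for macro-generated code: duplication disappears, but the generated function names do not appear literally in the source, so searching for them requires knowing the expansion rule.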

View File

@ -0,0 +1,49 @@
#include "ruby_whisper.h"
#include <stdio.h>
#include <unistd.h>
extern VALUE mParakeet;
extern VALUE mLogSettable;
extern VALUE cParakeetContext;
extern VALUE cParakeetSegment;
extern VALUE mOutputContext;
extern VALUE mOutputSegment;
extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet);
extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet);
extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet);
extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet);
extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext);
extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet);
static ruby_whisper_log_queue parakeet_log_queue;
LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set)
static VALUE
ruby_whisper_parakeet_s_system_info_str(VALUE self)
{
return rb_str_new2(parakeet_print_system_info());
}
void
init_ruby_whisper_parakeet(VALUE *mWhisper)
{
mParakeet = rb_define_module_under(*mWhisper, "Parakeet");
rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version()));
LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet)
rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0);
init_ruby_whisper_parakeet_params(&mParakeet);
init_ruby_whisper_parakeet_token(&mParakeet);
init_ruby_whisper_parakeet_segment(&mParakeet);
cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet);
init_ruby_whisper_parakeet_context_params(&cParakeetContext);
init_ruby_whisper_parakeet_model(&mParakeet);
rb_include_module(cParakeetContext, mOutputContext);
rb_include_module(cParakeetSegment, mOutputSegment);
}

View File

@ -0,0 +1,304 @@
#include "ruby_whisper.h"
#define ITERATE_SEGMENT_ATTRS(ITERATOR) \
ITERATOR(get_segment_t0, LONG) \
ITERATOR(get_segment_t1, LONG) \
ITERATOR(get_segment_text, STRING) \
ITERATOR(n_tokens, INT)
#define ITERATE_TOKEN_ATTRS(ITERATOR) \
ITERATOR(get_token_text, STRING) \
ITERATOR(get_token_id, INT) \
ITERATOR(get_token_p, FLOAT)
#define VAL_FROM_LONG(v) LONG2NUM(v)
#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v)
#define VAL_FROM_INT(v) INT2NUM(v)
#define VAL_FROM_FLOAT(v) DBL2NUM(v)
#define READER(type) VAL_FROM_##type
extern ID id_to_s;
extern ID id___method__;
extern ID id_to_enum;
extern ID id_new;
extern VALUE cParakeetContext;
extern VALUE eError;
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples);
extern VALUE release_samples(VALUE rb_parsed_args);
extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
extern rb_data_type_t ruby_whisper_parakeet_params_type;
extern rb_data_type_t ruby_whisper_parakeet_context_params_type;
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context);
static void
ruby_whisper_parakeet_context_free(void *p)
{
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
if (rwpc->context) {
parakeet_free(rwpc->context);
rwpc->context = NULL;
}
xfree(rwpc);
}
static size_t
ruby_whisper_parakeet_context_memsize(const void *p)
{
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
if (!rwpc) {
return 0;
}
size_t size = sizeof(*rwpc);
return size;
}
const rb_data_type_t ruby_whisper_parakeet_context_type = {
"ruby_whisper_parakeet_context",
{0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_parakeet_context_allocate(VALUE klass)
{
ruby_whisper_parakeet_context *rwpc;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
rwpc->context = NULL;
return obj;
}
typedef struct {
struct parakeet_context **context;
char *model_path;
struct parakeet_context_params params;
} ruby_whisper_parakeet_context_init_args;
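/* Loads the model. Invoked through rb_thread_call_without_gvl so that other Ruby threads can run during the load. */
static void*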
ruby_whisper_parakeet_context_init_without_gvl(void *args)
{
ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args;
*init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params);
return NULL;
}
static VALUE
ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self)
{
ruby_whisper_parakeet_context *rwpc;
VALUE model_path;
VALUE context_params;
struct parakeet_context_params params;
rb_scan_args(argc, argv, "11", &model_path, &context_params);
TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
model_path = ruby_whisper_normalize_model_path(model_path);
if (!rb_respond_to(model_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context");
}
if (NIL_P(context_params)) {
params = parakeet_context_default_params();
} else {
ruby_whisper_parakeet_context_params *rwpcp;
GetParakeetContextParams(context_params, rwpcp);
params = rwpcp->params;
}
ruby_whisper_parakeet_context_init_args init_args = {
&rwpc->context,
StringValueCStr(model_path),
params,
};
rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL);
if (rwpc->context == NULL) {
rb_raise(rb_eRuntimeError, "Failed to load model");
}
return Qnil;
}
static VALUE
ruby_whisper_parakeet_context_full_n_segments(VALUE self)
{
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(self, rwpc);
return INT2NUM(parakeet_full_n_segments(rwpc->context));
}
#define DEF_SEGMENT_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \
{ \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetContext(self, rwpc); \
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \
}
ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR)
#define DEF_TOKEN_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \
{ \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetContext(self, rwpc); \
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \
}
ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR)
static VALUE
ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token)
{
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(self, rwpc);
parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token));
return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data);
}
static VALUE
ruby_whisper_parakeet_context_each_segment(VALUE self)
{
/* Without a block, return an Enumerator, equivalent to to_enum(__method__). */
if (!rb_block_given_p()) {
const VALUE method_name = rb_funcall(self, id___method__, 0);
return rb_funcall(self, id_to_enum, 1, method_name);
}
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(self, rwpc);
const int n_segments = parakeet_full_n_segments(rwpc->context);
for (int i = 0; i < n_segments; ++i) {
rb_yield(ruby_whisper_parakeet_segment_init(self, i));
}
return self;
}
typedef struct {
struct parakeet_context *context;
struct parakeet_full_params *params;
float *samples;
int n_samples;
int result;
} parakeet_full_without_gvl_args;
static void*
parakeet_full_without_gvl(void *rb_args)
{
parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args;
args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples);
return NULL;
}
typedef struct {
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
} parakeet_full_ubf_args;
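/* Unblocking function for rb_thread_call_without_gvl: flags the run as interrupted; the abort callback reads this flag. */
static void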
parakeet_full_ubf(void *rb_args)
{
parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args;
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
}
VALUE
ruby_whisper_parakeet_context_full_body(VALUE rb_args)
{
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(*args->context, rwpc);
ruby_whisper_parakeet_params *rwpp;
GetParakeetParams(*args->params, rwpp);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data);
parakeet_full_without_gvl_args full_without_gvl_args = {
rwpc->context,
&rwpp->params,
args->samples,
args->n_samples,
0
};
parakeet_full_ubf_args full_ubf_args = {
&abort_callback_user_data,
};
rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_without_gvl_args.result);
}
static VALUE
ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self)
{
if (argc < 2 || argc > 3) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
VALUE n_samples = argc == 2 ? Qnil : argv[2];
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
ruby_whisper_full_args args = {
&self,
&argv[0],
parsed.samples,
parsed.n_samples,
};
VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed);
const int result = NUM2INT(rb_result);
if (result == 0) {
return self;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
}
}
static VALUE
ruby_whisper_parakeet_context_get_model(VALUE self)
{
return ruby_whisper_parakeet_model_s_new(self);
}
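/*
 * Usage sketch for the methods registered below. This is an illustration, not
 * part of the original patch: the model path and `samples` are placeholders,
 * and `samples` is assumed to hold mono float PCM audio.
 *
 *   context = Whisper::Parakeet::Context.new("path/to/parakeet-model.bin")
 *   params  = Whisper::Parakeet::Params.new
 *   context.full(params, samples)   # or context.full(params, samples, n_samples)
 *   context.each_segment do |segment|
 *     puts "#{segment.start_time}-#{segment.end_time}: #{segment.text}"
 *   end
 */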
VALUE
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
{
cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);
rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate);
rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1);
rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2);
rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0);
rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2);
rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0);
rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0);
rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1);
#define REGISTER_SEGMENT_ATTR(name, type) \
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1);
ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR)
#define REGISTER_TOKEN_ATTR(name, type) \
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2);
ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR)
return cParakeetContext;
}

@ -0,0 +1,117 @@
#include "ruby_whisper.h"
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(use_gpu, BOOL) \
ITERATOR(gpu_device, INT)
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_TO_INT(v) (NUM2INT(v))
#define READER(type) VAL_FROM_##type
#define WRITER(type) VAL_TO_##type
#define DEF_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_context_params_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_context_params *rwpcp; \
GetParakeetContextParams(self, rwpcp); \
return READER(type)(rwpcp->params.name); \
} \
static VALUE \
ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_context_params *rwpcp; \
GetParakeetContextParams(self, rwpcp); \
rwpcp->params.name = WRITER(type)(val); \
return val; \
}
enum {
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name,
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS
};
extern VALUE cParakeetContextParams;
typedef VALUE (*param_writer_t)(VALUE, VALUE);
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
static size_t
ruby_whisper_parakeet_context_params_memsize(const void *p)
{
if (!p) {
return 0;
}
return sizeof(ruby_whisper_parakeet_context_params);
}
const rb_data_type_t ruby_whisper_parakeet_context_params_type = {
"ruby_whisper_parakeet_context_params",
{0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,},
0, 0,
0,
};
static VALUE
ruby_whisper_parakeet_context_params_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_context_params *rwpcp;
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
}
static VALUE
ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef};
VALUE value;
ruby_whisper_parakeet_context_params *rwpcp;
int i;
TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
rwpcp->params = parakeet_context_default_params();
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return Qnil;
}
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values);
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) {
value = values[i];
if (value == Qundef) {
continue;
}
param_writers[i](self, value);
}
return Qnil;
}
ITERATE_ATTRS(DEF_ATTR)
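/*
 * Usage sketch for the class registered below (illustrative only; the values
 * and the model path are placeholders). Context#initialize accepts the params
 * object as its optional second argument.
 *
 *   context_params = Whisper::Parakeet::Context::Params.new(use_gpu: true, gpu_device: 0)
 *   context_params.use_gpu = false
 *   context = Whisper::Parakeet::Context.new("path/to/parakeet-model.bin", context_params)
 */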
void
init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext)
{
cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject);
rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate);
rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1);
int i = 0;
#define REGISTER_ATTR(name, type) \
param_names[i] = rb_intern(#name); \
param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \
rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \
rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \
i++;
ITERATE_ATTRS(REGISTER_ATTR)
}

@ -0,0 +1,84 @@
#include "ruby_whisper.h"
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(n_vocab) \
ITERATOR(n_audio_ctx) \
ITERATOR(n_audio_state) \
ITERATOR(n_audio_head) \
ITERATOR(n_audio_layer) \
ITERATOR(n_mels) \
ITERATOR(ftype)
extern rb_data_type_t ruby_whisper_parakeet_context_type;
extern VALUE cParakeetModel;
static void
ruby_whisper_parakeet_model_mark(void *p)
{
ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p;
if (!NIL_P(rwpm->context)) {
rb_gc_mark(rwpm->context);
}
}
static size_t
ruby_whisper_parakeet_model_memsize(const void *p)
{
if (!p) {
return 0;
}
return sizeof(ruby_whisper_parakeet_model);
}
static const rb_data_type_t ruby_whisper_parakeet_model_type = {
"ruby_whisper_parakeet_model",
{ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize},
0, 0,
0
};
static VALUE
ruby_whisper_parakeet_model_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_model *rwpm;
VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
rwpm->context = Qnil;
return model;
}
VALUE
ruby_whisper_parakeet_model_s_new(VALUE context)
{
const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel);
ruby_whisper_parakeet_model *rwpm;
TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
rwpm->context = context;
return model;
}
#define DEF_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_model_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_model *rwpm; \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetModel(self, rwpm); \
GetParakeetContext(rwpm->context, rwpc); \
return INT2NUM(parakeet_model_##name(rwpc->context)); \
}
ITERATE_ATTRS(DEF_ATTR)
void
init_ruby_whisper_parakeet_model(VALUE *mParakeet)
{
cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject);
rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate);
#define REGISTER_ATTR(name) \
rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0);
ITERATE_ATTRS(REGISTER_ATTR)
}

@ -0,0 +1,548 @@
#include "ruby_whisper.h"
#define ITERATE_PARAMS(ITERATOR) \
ITERATOR(n_threads, INT) \
ITERATOR(offset_ms, INT) \
ITERATOR(duration_ms, INT) \
ITERATOR(no_context, BOOL) \
ITERATOR(audio_ctx, INT)
#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \
ITERATOR(new_segment, DATA) \
ITERATOR(new_token, DATA) \
ITERATOR(progress, DATA) \
ITERATOR(encoder_begin, DATA)
#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback)
#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR)
#define ITERATE_CALLBACK_PARAMS(ITERATOR) \
ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
ITERATOR(abort_callback)
enum {
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name,
#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name,
#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data,
ITERATE_PARAMS(DEF_IDX)
ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK)
ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA)
RUBY_WHISPER_PARAKEET_NUM_PARAMS
};
#define VAL_TO_INT(v) (NUM2INT(v))
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
extern VALUE cParakeetParams;
extern ID id_call;
extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
typedef VALUE (*param_writer_t)(VALUE, VALUE);
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_state *state;
int n_new;
} call_parakeet_new_segment_callbacks_args;
static void*
call_parakeet_new_segment_callbacks(void *v_args)
{
call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data);
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
const int n_segments = parakeet_full_n_segments_from_state(args->state);
/* The n_new newest segments are the last n_new entries; visit them oldest first. */
for (int i = args->n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment);
for (long j = 0; j < n_callbacks; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
return NULL;
}
static void
ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_new_segment_callbacks_args args = {
container,
state,
n_new,
};
rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_context *context;
struct parakeet_state *state;
const parakeet_token_data *token_data;
} call_parakeet_new_token_callbacks_args;
static void*
call_parakeet_new_token_callbacks(void *v_args)
{
call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args;
VALUE token = Qnil;
const ruby_whisper_callback_container *container = args->container;
if (!NIL_P(container->callback)) {
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data);
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
if (NIL_P(token)) {
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
}
for (long i = 0; i < n_callbacks; i++) {
VALUE cb = rb_ary_entry(container->callbacks, i);
rb_funcall(cb, id_call, 1, token);
}
return NULL;
}
static void
ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_new_token_callbacks_args args = {
container,
context,
state,
token_data,
};
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_state *state;
int progress;
} call_parakeet_progress_callbacks_args;
static void*
call_parakeet_progress_callback(void *v_args)
{
call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data);
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
for (long i = 0; i < n_callbacks; i++) {
VALUE cb = rb_ary_entry(container->callbacks, i);
rb_funcall(cb, id_call, 1, INT2NUM(args->progress));
}
return NULL;
}
static void
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_progress_callbacks_args args = {
container,
state,
progress,
};
rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_state *state;
bool is_continued;
} call_parakeet_encoder_begin_callbacks_args;
static void*
call_parakeet_encoder_begin_callbacks(void *v_args)
{
call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
if (result == Qfalse) {
args->is_continued = false;
return NULL;
}
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
for (long i = 0; i < n_callbacks; i++) {
VALUE cb = rb_ary_entry(container->callbacks, i);
result = rb_funcall(cb, id_call, 0);
if (result == Qfalse) {
args->is_continued = false;
return NULL;
}
}
return NULL;
}
static bool
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return true;
}
call_parakeet_encoder_begin_callbacks_args args = {
container,
state,
true,
};
rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args);
return args.is_continued;
}
typedef struct {
const ruby_whisper_callback_container *container;
bool is_interrupted;
} call_parakeet_abort_callbacks_args;
static void*
call_parakeet_abort_callbacks(void *v_args)
{
call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
VALUE cb;
for (long i = 0; i < n_callbacks; i++) {
cb = rb_ary_entry(container->callbacks, i);
result = rb_funcall(cb, id_call, 0);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
}
return NULL;
}
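/* Returns true to abort: either the interrupt flag is set or a Ruby abort callback returned a truthy value. */
static bool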
ruby_whisper_parakeet_abort_callback(void *user_data)
{
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
if (is_interrupted) {
return true;
}
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
return false;
}
call_parakeet_abort_callbacks_args args = {
data->callback_container,
false,
};
rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args);
return args.is_interrupted;
}
#define CALLBACK_CONTAINER_NAME(name) name ## _container
void
ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
{
#define PARAM_NAME(name) name
#define USER_DATA_NAME(name) name##_user_data
#define REGISTER_CALLBACK(name) \
if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \
rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \
rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \
rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \
}
ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK)
if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) {
abort_callback_user_data->callback_container = rwpp->abort_callback_container;
}
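/* Always install the abort callback, even without Ruby-side callbacks, so that the interrupt flag is honored. */
rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback;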
rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data;
}
static void
ruby_whisper_parakeet_params_mark(void *p)
{
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
#define MARK_CONTAINER(name) \
if (rwpp->name##_container) { \
ruby_whisper_callback_container_mark(rwpp->name##_container); \
}
ITERATE_CALLBACK_PARAMS(MARK_CONTAINER)
}
static void
ruby_whisper_parakeet_params_free(void *p)
{
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
#define FREE_CONTAINER(name) \
if (rwpp->name##_container) { \
xfree(rwpp->name##_container); \
}
ITERATE_CALLBACK_PARAMS(FREE_CONTAINER)
xfree(rwpp);
}
static size_t
ruby_whisper_parakeet_params_memsize(const void *p)
{
const ruby_whisper_parakeet_params *params = (const ruby_whisper_parakeet_params *)p;
if (!params) {
return 0;
}
return sizeof(ruby_whisper_parakeet_params);
}
const rb_data_type_t ruby_whisper_parakeet_params_type = {
"ruby_whisper_parakeet_params",
{ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,},
0, 0,
0
};
#define READER(type) VAL_FROM_##type
#define WRITER(type) VAL_TO_##type
#define DEF_PARAM_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_params_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
return READER(type)(rwpp->params.name); \
} \
static VALUE \
ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
rwpp->params.name = WRITER(type)(val); \
return val; \
}
#define DEF_CALLBACK_PARAM_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_params_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \
} \
static VALUE \
ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \
return val; \
}
#define DEF_USER_DATA_PARAM_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \
} \
static VALUE \
ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \
return val; \
}
#define DEF_HOOK(name, data) \
static VALUE \
ruby_whisper_parakeet_params_on_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
const VALUE blk = rb_block_proc(); \
if (NIL_P(rwpp->name##_callback_container->callbacks)) { \
rwpp->name##_callback_container->callbacks = rb_ary_new(); \
} \
rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \
return Qnil; \
}
ITERATE_PARAMS(DEF_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR)
ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
static VALUE
ruby_whisper_parakeet_params_abort_on(VALUE self)
{
ruby_whisper_parakeet_params *rwpp;
GetParakeetParams(self, rwpp);
const VALUE blk = rb_block_proc();
if (NIL_P(rwpp->abort_callback_container->callbacks)) {
rwpp->abort_callback_container->callbacks = rb_ary_new();
}
rb_ary_push(rwpp->abort_callback_container->callbacks, blk);
return Qnil;
}
static VALUE
ruby_whisper_parakeet_params_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_params *rwpp;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
return obj;
}
static VALUE
ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef};
VALUE value;
ruby_whisper_parakeet_params *rwpp;
int i;
TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate();
ITERATE_CALLBACK_PARAMS(INIT_CONTAINER)
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return Qnil;
}
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values);
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) {
value = values[i];
if (value == Qundef) {
continue;
}
param_writers[i](self, value);
}
return Qnil;
}
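/*
 * Usage sketch for the parameters and hooks registered below (illustrative
 * only; the values are arbitrary and the block bodies are placeholders):
 *
 *   params = Whisper::Parakeet::Params.new(n_threads: 4, no_context: true)
 *   params.offset_ms = 1_000
 *   params.on_new_segment { |segment| puts segment.text }
 *   params.on_progress { |progress| warn "progress: #{progress}" }
 *   params.abort_on { false }   # a truthy return value aborts transcription
 */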
void
init_ruby_whisper_parakeet_params(VALUE *mParakeet)
{
cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject);
rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate);
rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1);
int i = 0;
#define REGISTER_PARAM(name) \
param_names[i] = rb_intern(#name); \
param_writers[i] = ruby_whisper_parakeet_params_set_##name; \
rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \
rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \
i++;
#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name)
#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name)
#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data)
ITERATE_PARAMS(REGISTER_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR)
#define REGISTER_HOOK(name, data) \
rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0);
ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _)
rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0);
}

@ -0,0 +1,157 @@
#include "ruby_whisper.h"
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(start_time, t0, TIME) \
ITERATOR(end_time, t1, TIME) \
ITERATOR(text, text, STRING)
enum {
#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name,
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
};
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
#define VAL_FROM_STRING(v) (rb_str_new2(v))
#define READER(type) VAL_FROM_##type
#define DEF_ATTR(rb_name, c_name, type) \
static VALUE \
ruby_whisper_parakeet_get_##rb_name(VALUE self) \
{ \
ruby_whisper_parakeet_segment *rwps; \
GetParakeetSegment(self, rwps); \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetContext(rwps->context, rwpc); \
return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \
}
extern ID id___method__;
extern ID id_to_enum;
extern VALUE cParakeetSegment;
extern VALUE sym_start_time;
extern VALUE sym_end_time;
extern VALUE sym_text;
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token);
static void
rb_whisper_parakeet_segment_mark(void *p)
{
ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p;
rb_gc_mark(rwps->context);
}
static size_t
ruby_whisper_parakeet_segment_memsize(const void *p)
{
const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p;
if (!rwps) {
return 0;
}
return sizeof(*rwps);
}
static const rb_data_type_t ruby_whisper_parakeet_segment_type = {
"ruby_whisper_parakeet_segment",
{rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_parakeet_segment_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_segment *rwps;
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
}
VALUE
ruby_whisper_parakeet_segment_init(VALUE context, int index)
{
ruby_whisper_parakeet_segment *rwps;
const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment);
TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
rwps->context = context;
rwps->index = index;
return segment;
}
ITERATE_ATTRS(DEF_ATTR)
static VALUE
ruby_whisper_parakeet_segment_each_token(VALUE self)
{
if (!rb_block_given_p()) {
const VALUE method_name = rb_funcall(self, id___method__, 0);
return rb_funcall(self, id_to_enum, 1, method_name);
}
ruby_whisper_parakeet_segment *rwps;
GetParakeetSegment(self, rwps);
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(rwps->context, rwpc);
const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index);
for (int i = 0; i < n_tokens; i++) {
rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i));
}
return self;
}
static VALUE
ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys)
{
ruby_whisper_parakeet_segment *rwps;
GetParakeetSegment(self, rwps);
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(rwps->context, rwpc);
VALUE hash = rb_hash_new();
long n_keys;
if (NIL_P(keys)) {
keys = rb_ary_new3(
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
sym_start_time,
sym_end_time,
sym_text
);
n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS;
} else {
n_keys = RARRAY_LEN(keys);
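/* More keys were requested than attributes exist, so the pattern cannot match: return an empty hash. */
if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) {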
return hash;
}
}
for (int i = 0; i < n_keys; i++) {
VALUE key = rb_ary_entry(keys, i);
#define CHECK_AND_SET_KEY(rb_name, c_name, type) \
if (key == sym_##rb_name) { \
rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \
}
ITERATE_ATTRS(CHECK_AND_SET_KEY)
}
return hash;
}
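/*
 * deconstruct_keys lets a segment take part in Ruby pattern matching. A sketch,
 * assuming `segment` was yielded by Context#each_segment:
 *
 *   case segment
 *   in {start_time:, end_time:, text:}
 *     puts "[#{start_time} --> #{end_time}] #{text}"
 *   end
 */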
void
init_ruby_whisper_parakeet_segment(VALUE *mParakeet)
{
cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject);
rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate);
#define REGISTER_ATTR(rb_name, c_name, type) \
rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0);
ITERATE_ATTRS(REGISTER_ATTR)
rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0);
rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1);
}
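
The `each_token` function above returns `to_enum(__method__)` when called without a block and returns `self` otherwise. A minimal pure-Ruby sketch of that convention follows; `SketchSegment` and its token array are hypothetical stand-ins, not the real `Whisper::Parakeet::Segment`, which stores only a context reference and a segment index.

```ruby
# Hypothetical stand-in for a segment; the real class is defined in C
# and reads tokens from the parakeet context on demand.
class SketchSegment
  def initialize(tokens)
    @tokens = tokens
  end

  def each_token
    # Same convention as the C function: no block means "return an Enumerator".
    return to_enum(__method__) unless block_given?

    @tokens.each { |token| yield token }
    self
  end
end

segment = SketchSegment.new(%w[ask not what])
enum = segment.each_token                   # Enumerator, nothing yielded yet
returned = segment.each_token { |_token| }  # block form returns the segment
```

Returning an `Enumerator` lets callers chain `map`, `select`, or `with_index` without the extension defining each of those methods itself.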

@ -0,0 +1,188 @@
#include "ruby_whisper.h"
#define ITERATE_MEMBERS(ITERATOR) \
ITERATOR(id, id, id, id, INT) \
ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \
ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \
ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \
ITERATOR(probability, probability, p, p, FLOAT) \
ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \
ITERATOR(start_time, start_time, start_time, t0, TIME) \
ITERATOR(end_time, end_time, end_time, t1, TIME) \
ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL)
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(text, text, text, text, STRING)
enum {
#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name,
ITERATE_MEMBERS(DEF_IDX)
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS,
};
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_FROM_FLOAT(v) (DBL2NUM(v))
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
#define VAL_FROM_STRING(v) (rb_str_new2(v))
#define READER(type) VAL_FROM_##type
#define MEMBER_NAME(name) name
#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \
static VALUE \
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
{ \
ruby_whisper_parakeet_token *rwpt; \
GetParakeetToken(self, rwpt); \
return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \
}
#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \
static VALUE \
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
{ \
ruby_whisper_parakeet_token *rwpt; \
GetParakeetToken(self, rwpt); \
return rwpt->p_name; \
}
VALUE cParakeetToken;
#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key;
ITERATE_MEMBERS(DEC_ATTR_SYMS)
ITERATE_ATTRS(DEC_ATTR_SYMS)
static void
ruby_whisper_parakeet_token_mark(void *p)
{
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
rb_gc_mark(rwpt->text);
}
static void
ruby_whisper_parakeet_token_free(void *p)
{
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
if (rwpt->token_data) {
xfree(rwpt->token_data);
rwpt->token_data = NULL;
}
xfree(rwpt);
}
static size_t
ruby_whisper_parakeet_token_memsize(const void *p)
{
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
if (!rwpt) {
return 0;
}
size_t size = sizeof(*rwpt);
if (rwpt->token_data) {
size += sizeof(*rwpt->token_data);
}
return size;
}
static const rb_data_type_t ruby_whisper_parakeet_token_type = {
"ruby_whisper_parakeet_token",
{ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize},
0, 0,
0,
};
static VALUE
ruby_whisper_parakeet_token_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_token *rwpt;
VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
rwpt->token_data = NULL;
rwpt->text = Qnil;
return token;
}
VALUE
ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
{
const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
ruby_whisper_parakeet_token *rwpt;
TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
rwpt->token_data = ALLOC(parakeet_token_data);
*rwpt->token_data = *token_data;
rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id));
return token;
}
VALUE
ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
{
parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token);
return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data);
}
ITERATE_MEMBERS(DEF_MEMBER_ATTR)
// Define #text using parakeet_token_to_str or parakeet_token_to_text
ITERATE_ATTRS(DEF_ATTR)
static VALUE
ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys)
{
ruby_whisper_parakeet_token *rwpt;
GetParakeetToken(self, rwpt);
VALUE hash = rb_hash_new();
long n_keys = 0;
if (NIL_P(keys)) {
VALUE attrs[] = {
#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key,
ITERATE_MEMBERS(LIST_SYMS)
ITERATE_ATTRS(LIST_SYMS)
};
keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs);
n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS;
} else {
n_keys = RARRAY_LEN(keys);
if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) {
return hash;
}
}
for (long i = 0; i < n_keys; i++) {
VALUE key = rb_ary_entry(keys, i);
#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \
if (key == sym_##s_key) { \
rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \
}
ITERATE_MEMBERS(CHECK_AND_SET_KEY)
ITERATE_ATTRS(CHECK_AND_SET_KEY)
}
return hash;
}
void
init_ruby_whisper_parakeet_token(VALUE *mParakeet)
{
cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject);
rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate);
#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \
sym_##s_key = ID2SYM(rb_intern(#s_key)); \
rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0);
ITERATE_MEMBERS(REGISTER_ATTR)
ITERATE_ATTRS(REGISTER_ATTR)
rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1);
}
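
`deconstruct_keys` is what makes a token usable in Ruby hash patterns. The sketch below imitates the C function's behaviour as it reads from the hunk above: `nil` means all attributes, an over-long key list yields an empty hash, and unknown keys are skipped. `SketchToken` and its reduced attribute list are assumptions for illustration only.

```ruby
# Hypothetical stand-in for Whisper::Parakeet::Token with a reduced attribute set.
class SketchToken
  ATTRS = %i[id probability start_time end_time word_start text].freeze

  def initialize(**values)
    @values = values
  end

  def deconstruct_keys(keys)
    keys ||= ATTRS                           # nil requests every attribute
    return {} if keys.length > ATTRS.length  # more keys than attributes: no match

    keys.each_with_object({}) do |key, hash|
      hash[key] = @values[key] if ATTRS.include?(key)
    end
  end
end

token = SketchToken.new(id: 42, probability: 0.9, start_time: 120,
                        end_time: 200, word_start: true, text: "ask")

description =
  case token
  in {text: String => text, word_start: true}
    "word start: #{text}"
  in {text: String => text}
    "continuation: #{text}"
  end
```

Note that the reader method is registered as `word_start?`, while the pattern key in the macro table is `word_start`.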

@ -0,0 +1,58 @@
#include "ruby_whisper.h"
#include "common-whisper.h"
#include <string>
#include <vector>
#ifdef __cplusplus
extern "C" {
#endif
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern const rb_data_type_t ruby_whisper_parakeet_params_type;
extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args);
extern ID id_to_path;
extern ID id_new;
extern VALUE eError;
VALUE
ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params)
{
if (rb_respond_to(audio_path, id_to_path)) {
audio_path = rb_funcall(audio_path, id_to_path, 0);
}
std::string fname = StringValueCStr(audio_path);
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
if (!read_audio_data(fname, pcmf32, pcmf32s, false)) {
rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str());
return Qnil;
}
ruby_whisper_parakeet_context *rwpc;
ruby_whisper_parakeet_params *rwpp;
GetParakeetContext(self, rwpc);
GetParakeetParams(params, rwpp);
ruby_whisper_full_args args = {
&self,
&params,
pcmf32.data(),
(int)pcmf32.size(),
};
VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args);
const int result = NUM2INT(rb_result);
if (result == 0) {
return self;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
}
}
#ifdef __cplusplus
}
#endif
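
`ruby_whisper_parakeet_transcribe` accepts any object that responds to `to_path` before converting the argument to a C string. A rough Ruby equivalent of that normalization step is sketched below; `normalize_audio_path` is a hypothetical helper name, and its error classes only approximate what `StringValueCStr` raises.

```ruby
require "pathname"

# Hypothetical helper approximating the argument handling in the C++ function.
def normalize_audio_path(audio_path)
  # Pathname, File, and similar objects expose #to_path.
  audio_path = audio_path.to_path if audio_path.respond_to?(:to_path)

  # Approximation: the C macro performs an implicit String conversion and
  # raises TypeError for other objects; here a missing #to_str raises NoMethodError.
  path = audio_path.to_str
  raise ArgumentError, "string contains null byte" if path.include?("\0")

  path
end

from_pathname = normalize_audio_path(Pathname.new("samples/jfk.wav"))
from_string = normalize_audio_path("samples/jfk.wav")
```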

@ -76,8 +76,8 @@ static ID id_vad;
static ID id_vad_model_path;
static ID id_vad_params;
static void
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
void
ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc)
{
if (rwc == NULL) return;
@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
rb_gc_mark(rwc->callbacks);
}
static ruby_whisper_callback_container*
rb_whisper_callback_container_allocate() {
ruby_whisper_callback_container*
ruby_whisper_callback_container_allocate() {
ruby_whisper_callback_container *container;
container = ALLOC(ruby_whisper_callback_container);
container->context = NULL;
@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() {
return container;
}
static void
rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc)
{
if (rwc == NULL) return;
rb_gc_mark(rwc->user_data);
rb_gc_mark(rwc->callback);
rb_gc_mark(rwc->callbacks);
}
static ruby_whisper_abort_callback_container*
rb_whisper_abort_callback_container_allocate() {
ruby_whisper_abort_callback_container *container;
container = ALLOC(ruby_whisper_abort_callback_container);
container->context = NULL;
container->user_data = Qnil;
container->callback = Qnil;
container->callbacks = Qnil;
container->is_interrupted = false;
return container;
}
static bool
bool
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
}
static bool
ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) {
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct whisper_state *state;
@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s
}
typedef struct {
const ruby_whisper_abort_callback_container *container;
struct whisper_state *state;
const ruby_whisper_callback_container *container;
bool is_interrupted;
} call_abort_callbacks_args;
static void*
call_abort_callbacks(void *v_args) {
call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args;
const ruby_whisper_abort_callback_container *container = args->container;
if (container->is_interrupted) {
args->is_interrupted = true;
return NULL;
}
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) {
if (NIL_P(container->callbacks)) {
return NULL;
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (0 == n_callbacks) {
return NULL;
}
for (int j = 0; j < callbacks_len; j++) {
for (int j = 0; j < n_callbacks; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
VALUE result = rb_funcall(cb, id_call, 0);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) {
}
static bool abort_callback(void * user_data) {
const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data;
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
if (container->is_interrupted) {
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
if (is_interrupted) {
return true;
}
if (!ruby_whisper_abort_callback_container_is_present(container)) {
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
return false;
}
call_abort_callbacks_args args = {
container,
NULL,
data->callback_container,
false
};
rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args);
@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
return;
}
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription");
}
if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) {
rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription");
}
// new_segment_callback is called only after all worker threads have been joined
// progress_callback is not called during parallel transcription
if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) {
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
}
if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) {
if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) {
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
}
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
if (!NIL_P(log_callback)) {
rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription");
}
}
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) {
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
rwp->new_segment_callback_container->context = context;
rwp->params.new_segment_callback = new_segment_callback;
@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
abort_callback_user_data->callback_container = rwp->abort_callback_container;
rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback;
rwp->abort_callback_container->is_interrupted = false;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
rwp->params.abort_callback_user_data = (void *)abort_callback_user_data;
}
static void set_vad_params(ruby_whisper_params *rwp)
@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp)
rwp->params.vad_params = rwvp->params;
}
/*
TODO: Set abort callback to trap SIGINT and SIGTERM
*/
void
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
{
check_thread_safety(rwp, n_processors);
register_callbacks(rwp, context);
register_callbacks(rwp, context, abort_callback_user_data);
set_vad_params(rwp);
}
@ -421,10 +376,10 @@ void
rb_whisper_params_mark(void *p)
{
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_abort_callback_container_mark(rwp->abort_callback_container);
ruby_whisper_callback_container_mark(rwp->new_segment_callback_container);
ruby_whisper_callback_container_mark(rwp->progress_callback_container);
ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container);
ruby_whisper_callback_container_mark(rwp->abort_callback_container);
rb_gc_mark(rwp->vad_params);
}
@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass)
}
rwp->diarize = false;
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate();
rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate();
rwp->progress_callback_container = ruby_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate();
rwp->abort_callback_container = ruby_whisper_callback_container_allocate();
return obj;
}
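
The abort path in the hunks above can be summarised as a small decision function. This is a sketch under one assumption: in each changed pair of lines, the later line is the post-change code. Under that reading, the single callback receives `user_data`, the hook blocks are called without arguments, and any truthy result aborts. `SketchContainer` and `abort_requested?` are hypothetical names.

```ruby
# Hypothetical model of ruby_whisper_callback_container.
SketchContainer = Struct.new(:callback, :callbacks, :user_data) do
  # Mirrors ruby_whisper_callback_container_is_present.
  def present?
    !callback.nil? || !callbacks.nil?
  end
end

# Mirrors abort_callback plus call_abort_callbacks, minus the GVL handling.
def abort_requested?(interrupted, container)
  return true if interrupted                      # set by the interrupt flag
  return false if container.nil? || !container.present?
  return true if container.callback&.call(container.user_data)

  (container.callbacks || []).any? { |hook| hook.call }
end

with_callback = SketchContainer.new(->(user_data) { user_data == :stop }, nil, :stop)
with_hooks = SketchContainer.new(nil, [-> { false }, -> { true }], nil)
empty = SketchContainer.new(nil, nil, nil)
```

The truthiness test corresponds to `RTEST`, which treats everything except `nil` and `false` as true.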

@ -4,12 +4,12 @@
extern ID id___method__;
extern ID id_to_enum;
static VALUE sym_start_time;
static VALUE sym_end_time;
static VALUE sym_text;
static VALUE sym_no_speech_prob;
static VALUE sym_speaker_turn_next;
static VALUE sym_n_tokens;
VALUE sym_start_time;
VALUE sym_end_time;
VALUE sym_text;
VALUE sym_no_speech_prob;
VALUE sym_speaker_turn_next;
VALUE sym_n_tokens;
extern const rb_data_type_t ruby_whisper_type;

@ -16,6 +16,8 @@ extern ID id_to_path;
extern ID transcribe_option_names[1];
extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
extern VALUE full_body(VALUE rb_args);
extern VALUE full_parallel_body(VALUE rb_args);
typedef struct{
struct whisper_context *context;
@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args)
return NULL;
}
typedef struct {
ruby_whisper_abort_callback_container *abort_callback_container;
} transcribe_ubf_args;
static void
transcribe_ubf(void *rb_args)
{
transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args;
args->abort_callback_container->is_interrupted = true;
}
/*
* transcribe a single file
* can emit to a block results
@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
// Commented out because it is work in progress
// {
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
// bool is_aborted = *(bool*)user_data;
// return !is_aborted;
// };
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
// }
prepare_transcription(rwp, &self, n_processors);
transcribe_without_gvl_args args = {
rw->context,
&rwp->params,
pcmf32.data(),
pcmf32.size(),
n_processors,
0,
};
transcribe_ubf_args ubf_args = {
rwp->abort_callback_container,
};
rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args);
if (args.result != 0) {
VALUE rb_result;
if (n_processors == 1) {
ruby_whisper_full_args args = {
&self,
&params,
pcmf32.data(),
(int)pcmf32.size(),
};
rb_result = full_body((VALUE)&args);
} else {
ruby_whisper_full_parallel_args parallel_args = {
&self,
&params,
pcmf32.data(),
(int)pcmf32.size(),
n_processors,
};
rb_result = full_parallel_body((VALUE)&parallel_args);
}
const int result = NUM2INT(rb_result);
if (result != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}

@ -1,15 +0,0 @@
module Whisper
class Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end
def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
end
end

@ -0,0 +1,36 @@
require "mutex_m"
module Whisper
module LogSettable
class << self
def extended(base)
base.extend Mutex_m
end
end
private
def start_log_callback_thread
return if @log_callback_thread&.alive?
@log_callback_thread = Thread.new {
begin
while logs = drain_logs
begin
callback, user_data = synchronize {[@log_callback, @log_callback_user_data]}
next if callback.nil?
logs.each do |(level, text)|
callback.call level, text, user_data
end
rescue => err
$stderr.puts err
end
end
rescue => err
$stderr.puts err
end
}
end
end
end
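
`LogSettable` relies on a native `drain_logs` method that is not part of this hunk. To show the shape of the loop in isolation, the sketch below replaces it with a `Queue` whose `pop` returns `nil` once the queue is closed and empty. `SketchLogger`, `enqueue`, and `close` are hypothetical, and a plain `Mutex` is used so the sketch runs without the `mutex_m` gem.

```ruby
# Hypothetical pure-Ruby model of the log callback thread.
class SketchLogger
  def initialize
    @mutex = Mutex.new
    @queue = Queue.new
    @log_callback = nil
    @log_callback_user_data = nil
  end

  def log_set(callback, user_data)
    @mutex.synchronize do
      @log_callback = callback
      @log_callback_user_data = user_data
    end
    start_log_callback_thread
    nil
  end

  # Stand-in for the native side pushing one batch of [level, text] pairs.
  def enqueue(level, text)
    @queue << [[level, text]]
  end

  def close
    @queue.close
    @log_callback_thread&.join
  end

  private

  def drain_logs
    @queue.pop # blocks for the next batch; nil once closed and empty
  end

  def start_log_callback_thread
    return if @log_callback_thread&.alive?

    @log_callback_thread = Thread.new do
      while (logs = drain_logs)
        callback, user_data = @mutex.synchronize { [@log_callback, @log_callback_user_data] }
        next if callback.nil?

        begin
          logs.each { |(level, text)| callback.call(level, text, user_data) }
        rescue => err
          $stderr.puts err # a failing callback must not kill the thread
        end
      end
    end
  end
end

received = []
logger = SketchLogger.new
logger.log_set(->(level, text, user_data) { received << [level, text, user_data] }, :ud)
logger.enqueue(2, "hello")
logger.close
```

Reading the callback and its user data under the lock, then calling the callback outside it, keeps a slow callback from blocking a concurrent `log_set`.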

@ -41,6 +41,8 @@ module Whisper
def cache
path = cache_path
return path if cache_path.exist?
headers = {}
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
request @uri, headers
@ -216,8 +218,18 @@ module Whisper
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
end
%w[
parakeet-tdt-0.6b-v3-f16
parakeet-tdt-0.6b-v3-f32
parakeet-tdt-0.6b-v3-q4_0
parakeet-tdt-0.6b-v3-q4_k
parakeet-tdt-0.6b-v3-q8_0
].each do |name|
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin")
end
@coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models|
next if name.end_with?("-tdrz") || name.start_with?("silero-")
next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-")
if matched = name.match(/\A(?<name>.*)-q\d_\d\z/)
name = matched[:name]
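
The registration added in this hunk maps each Parakeet model name to a download URL and then excludes those names from the Core ML list. A standalone sketch of that mapping, using the standard library's `URI` instead of the gem's own `Whisper::Model::URI` class (an assumption made so the example runs on its own):

```ruby
require "uri"

# Variant suffixes copied from the hunk above.
PARAKEET_VARIANTS = %w[f16 f32 q4_0 q4_k q8_0].freeze

parakeet_models = PARAKEET_VARIANTS.to_h do |variant|
  name = "parakeet-tdt-0.6b-v3-#{variant}"
  [name, URI("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin")]
end

# Same prefix test the hunk adds before building the Core ML model list.
coreml_candidates = parakeet_models.keys.reject { |name| name.start_with?("parakeet-") }
```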

@ -0,0 +1,74 @@
module Whisper
module Output
module Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end
def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
end
module Segment
SRT_ESCAPES = {
"&" => "&amp;",
"<" => "&lt;",
">" => "&gt;",
}
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
def to_srt_cue
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
end
def to_webvtt_cue
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
end
private
def time_to_a(time)
sec, decimal_part = time.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
[hour, min, sec, decimal_part]
end
def srt_time(time)
"%02d:%02d:%02d,%03d" % time_to_a(time)
end
def srt_start_time
srt_time(start_time)
end
def srt_end_time
srt_time(end_time)
end
def srt_text
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
end
def webvtt_time(time)
"%02d:%02d:%02d.%03d" % time_to_a(time)
end
def webvtt_start_time
webvtt_time(start_time)
end
def webvtt_end_time
webvtt_time(end_time)
end
alias webvtt_text srt_text
end
end
end
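
The cue helpers in `Output::Segment` assume `start_time` and `end_time` are in milliseconds. The sketch below lifts the same arithmetic into a standalone module so the formatting can be checked by hand; `SrtSketch` and `srt_cue` are hypothetical names.

```ruby
# Hypothetical standalone version of the private helpers in Output::Segment.
module SrtSketch
  ESCAPES = { "&" => "&amp;", "<" => "&lt;", ">" => "&gt;" }.freeze
  ESCAPES_RE = Regexp.union(ESCAPES.keys)

  module_function

  # Split milliseconds into [hours, minutes, seconds, milliseconds].
  def time_to_a(time_ms)
    sec, millis = time_ms.divmod(1000)
    min, sec = sec.divmod(60)
    hour, min = min.divmod(60)
    [hour, min, sec, millis]
  end

  # SRT uses a comma before the milliseconds; WebVTT uses a period.
  def srt_time(time_ms)
    "%02d:%02d:%02d,%03d" % time_to_a(time_ms)
  end

  def srt_cue(start_ms, end_ms, text)
    "#{srt_time(start_ms)} --> #{srt_time(end_ms)}\n#{text.gsub(ESCAPES_RE, ESCAPES)}\n"
  end
end

timestamp = SrtSketch.srt_time(3_723_004)
cue = SrtSketch.srt_cue(0, 1_500, "a < b & c")
```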

@ -1,58 +0,0 @@
module Whisper
class Segment
SRT_ESCAPES = {
"&" => "&amp;",
"<" => "&lt;",
">" => "&gt;",
}
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
def to_srt_cue
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
end
def to_webvtt_cue
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
end
private
def time_to_a(time)
sec, decimal_part = time.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
[hour, min, sec, decimal_part]
end
def srt_time(time)
"%02d:%02d:%02d,%03d" % time_to_a(time)
end
def srt_start_time
srt_time(start_time)
end
def srt_end_time
srt_time(end_time)
end
def srt_text
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
end
def webvtt_time(time)
"%02d:%02d:%02d.%03d" % time_to_a(time)
end
def webvtt_start_time
webvtt_time(start_time)
end
def webvtt_end_time
webvtt_time(end_time)
end
alias webvtt_text srt_text
end
end

@ -40,7 +40,21 @@ module Whisper
def self.log_set: (log_callback?, Object? user_data) -> log_callback
def self.system_info_str: () -> String
module Output
module Context
def to_srt: () -> String
def to_webvtt: () -> String
end
module Segment
def to_srt_cue: () -> String
def to_webvtt_cue: () -> String
end
end
class Context
include Output::Context
def self.new: (String | path | ::URI::HTTP) -> instance
# transcribe a single file
@ -139,17 +153,14 @@ module Whisper
| (Whisper::Params, _Samples, ?Integer n_samples) -> self
| (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
def to_srt: () -> String
def to_webvtt: () -> String
class Params
def self.new: (
use_gpu: boolish,
flash_attn: boolish,
gpu_device: Integer,
dtw_token_timestamps: boolish,
dtw_aheads_preset: Integer,
dtw_n_top: Integer | nil,
?use_gpu: boolish,
?flash_attn: boolish,
?gpu_device: Integer,
?dtw_token_timestamps: boolish,
?dtw_aheads_preset: Integer,
?dtw_n_top: Integer | nil,
) -> instance
def use_gpu=: (boolish) -> boolish
@ -444,6 +455,9 @@ module Whisper
def abort_on: { (Object user_data) -> boolish } -> void
end
module LogSettable
end
class Model
def self.pre_converted_models: () -> Hash[String, Model::URI]
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
@ -474,6 +488,8 @@ module Whisper
end
class Segment
include Output::Segment
type deconstructed_keys = {
start_time: (Integer | nil),
end_time: (Integer | nil),
@ -514,9 +530,6 @@ module Whisper
#
def each_token: { (Token) -> void } -> void
| () -> Enumerator[Token]
def to_srt_cue: () -> String
def to_webvtt_cue: () -> String
# Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next`
#
@ -528,7 +541,7 @@ module Whisper
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys
end
module Token
class Token
type deconstructed_keys = {
id: (Integer | nil),
tid: (Integer | nil),
@ -598,6 +611,336 @@ module Whisper
def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys
end
module Parakeet
extend LogSettable
VERSION: String
# Control logging output. The default behavior is to print to stderr.
#
def self.log_set: (nil, Object? user_data) -> nil
| (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil
def self.system_info_str: () -> String
class Context
include Output::Context
# Load a Parakeet model from the given file path.
#
def self.new: (String | path | ::URI::HTTP, ?Params) -> instance
# Transcribe a single audio file.
#
def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
# Not thread safe for the same context.
#
# The second argument `samples` must be an array of samples, an object that responds to `:length`,
# or a MemoryView of a float array. It must contain 32-bit float PCM audio data.
#
def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self
| (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self
# Number of generated text segments.
#
def full_n_segments: () -> Integer
# Start time of a segment indexed by `segment_index` in centiseconds (units of 10 milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer segment_index) -> Integer
# End time of a segment indexed by `segment_index` in centiseconds (units of 10 milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer segment_index) -> Integer
# Text of a segment indexed by `segment_index`.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer segment_index) -> String
# Number of tokens in the segment indexed by `segment_index`.
#
def full_n_tokens: (Integer segment_index) -> Integer
# Text of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_text: (Integer segment_index, Integer token_index) -> String
# Token id of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer
# Probability of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_p: (Integer segment_index, Integer token_index) -> Float
# Token data of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_data: (Integer segment_index, Integer token_index) -> Token
def model: () -> Model
# Yields each Whisper::Parakeet::Segment:
#
# parakeet.transcribe("path/to/audio.wav", params)
# parakeet.each_segment do |segment|
# puts segment.text
# end
#
# Returns an `Enumerator` if no block is given:
#
# parakeet.transcribe("path/to/audio.wav", params)
# enum = parakeet.each_segment
# enum.to_a # => [#<Whisper::Parakeet::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
class Params
def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance
def use_gpu: () -> boolish
def use_gpu=: (boolish) -> boolish
def gpu_device: () -> Integer
def gpu_device=: (Integer) -> Integer
end
end
class Params
def self.new: (
?n_threads: Integer,
?offset_ms: Integer,
?duration_ms: Integer,
?no_context: boolish,
?audio_ctx: Integer,
?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void,
?new_segment_callback_user_data: Object,
?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void,
?new_token_callback_user_data: Object,
?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void,
?progress_callback_user_data: Object,
?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish,
?encoder_begin_callback_user_data: Object,
?abort_callback: ^(Object user_data) -> boolish,
?abort_callback_user_data: Object
) -> instance
# Number of threads to use.
#
def n_threads=: (Integer) -> Integer
def n_threads: () -> Integer
# Start offset in ms.
#
def offset_ms=: (Integer) -> Integer
def offset_ms: () -> Integer
# Audio duration to process in ms.
#
def duration_ms=: (Integer) -> Integer
def duration_ms: () -> Integer
# If `true`, does not use past transcription (if any) as context.
#
def no_context=: (boolish) -> boolish
def no_context: () -> (true | false)
# Overwrite the audio context size. `0` uses the default value.
#
def audio_ctx=: (Integer) -> Integer
def audio_ctx: () -> Integer
# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void)
def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil)
# Sets user data passed to the last argument of new segment callback.
#
def new_segment_callback_user_data=: (Object?) -> Object?
def new_segment_callback_user_data: () -> Object?
# Sets token callback, called for every newly predicted token.
#
def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void)
def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil)
# Sets user data passed to the last argument of token callback.
#
def new_token_callback_user_data=: (Object?) -> Object?
def new_token_callback_user_data: () -> Object?
# Sets progress callback, called on each progress update.
#
# `progress` is an Integer between 0 and 100.
#
def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void)
def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil)
# Sets user data passed to the last argument of progress callback.
#
def progress_callback_user_data=: (Object?) -> Object?
def progress_callback_user_data: () -> Object?
# Sets encoder begin callback, called each time before the encoder starts.
#
# If it returns `false`, the computation is aborted.
#
def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish)
def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil)
# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object?) -> Object?
def encoder_begin_callback_user_data: () -> Object?
# Sets abort callback, called each time before ggml computation starts.
#
def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish)
def abort_callback: () -> ((^(Object user_data) -> boolish) | nil)
# Sets user data passed to the last argument of abort callback.
#
def abort_callback_user_data=: (Object?) -> Object?
def abort_callback_user_data: () -> Object?
# Hook called on new segment. Yields each Whisper::Parakeet::Segment.
#
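# A minimal usage sketch, assuming +params+ is a Whisper::Parakeet::Params
# that is later passed to Whisper::Parakeet::Context#transcribe:
#
#   params.on_new_segment do |segment|
#     puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}"
#   end
#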
def on_new_segment: { (Segment) -> void } -> void
# Hook called on new token. Yields each Whisper::Parakeet::Token.
#
def on_new_token: { (Token) -> void } -> void
# Hook called on progress update. Yields each progress `Integer` between 0 and 100.
#
def on_progress: { (Integer progress) -> void } -> void
# Hook called each time before the encoder starts.
#
def on_encoder_begin: { () -> boolish } -> void
# Calls the block to determine whether to abort. Return `true` from the block to abort.
#
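# A minimal usage sketch, assuming +params+ is a Whisper::Parakeet::Params
# and +deadline+ is a hypothetical Time chosen by the caller:
#
#   params.abort_on do
#     Time.now > deadline
#   end
#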
def abort_on: { () -> boolish } -> void
end
class Segment
include Output::Segment
type deconstructed_keys = {
start_time: (Integer | nil),
end_time: (Integer | nil),
text: (String | nil)
}
# Start time in milliseconds.
#
def start_time: () -> Integer
# End time in milliseconds.
#
def end_time: () -> Integer
# Text of the segment.
#
def text: () -> String
# Yields each Whisper::Parakeet::Token:
#
#   parakeet.each_segment.first.each_token do |token|
#     p token
#   end
#
# Returns an `Enumerator` if no block is given:
#
#   parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...]
#
def each_token: { (Token) -> void } -> void
| () -> Enumerator[Token]
# Possible keys: `:start_time`, `:end_time`, `:text`
#
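# A minimal pattern-matching sketch, assuming +segment+ is a
# Whisper::Parakeet::Segment obtained from Context#each_segment:
#
#   case segment
#   in {start_time:, end_time:, text:}
#     puts "[#{start_time} --> #{end_time}] #{text}"
#   end
#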
def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys
end
class Token
type deconstructed_keys = {
id: (Integer | nil),
duration_idx: (Integer | nil),
duration_value: (Integer | nil),
frame_index: (Integer | nil),
probability: (Float | nil),
log_probability: (Float | nil),
start_time: (Integer | nil),
end_time: (Integer | nil),
word_start: ((true | false) | nil),
text: (String | nil),
}
# Token ID.
#
def id: () -> Integer
# Index into the model's durations array.
#
def duration_idx: () -> Integer
# Actual duration value.
#
def duration_value: () -> Integer
# Frame index of the token.
#
def frame_index: () -> Integer
# Probability of the token.
#
def probability: () -> Float
# Log probability of the token.
#
def log_probability: () -> Float
# Start time of the token in milliseconds.
#
def start_time: () -> Integer
# End time of the token in milliseconds.
#
def end_time: () -> Integer
# Whether this token is the start of a word.
#
def word_start?: () -> (true | false)
# Text of the token.
#
def text: () -> String
def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys
end
class Model
def n_vocab: () -> Integer
def n_audio_ctx: () -> Integer
def n_audio_state: () -> Integer
def n_audio_head: () -> Integer
def n_audio_layer: () -> Integer
def n_mels: () -> Integer
def ftype: () -> Integer
end
end
module VAD
class Params
def self.new: (

@@ -5,6 +5,8 @@ require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "fixtures", "jfk.wav")
Parakeet = Whisper::Parakeet
class << self
def whisper
return @whisper if @whisper

@@ -129,6 +129,7 @@ class TestCallback < TestBase
return false
}
@whisper.transcribe(@audio, @params)
sleep 0.5 # wait for logs to be dequeued
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
Whisper.log_set ->(level, buffer, user_data) {}, nil
end

@@ -0,0 +1,28 @@
require_relative "helper"
require "stringio"
class TestParakeet < TestBase
def test_log_set
log_callback = Parakeet.instance_variable_get("@log_callback")
user_data = Parakeet.instance_variable_get("@log_callback_user_data")
$stdout = StringIO.new
Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil
Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
sleep 0.1
$stdout.rewind
logs = $stdout.string
assert_match(/loading model from/, logs)
ensure
$stdout = STDOUT
Parakeet.log_set log_callback, user_data
end
def test_system_info_str
assert_match(/\APARAKEET : /, Parakeet.system_info_str)
end
def test_version
assert_instance_of String, Parakeet::VERSION
end
end

@@ -0,0 +1,107 @@
require_relative "helper"
class TestParakeetCallback < TestBase
def setup
omit "Skip not to download large model" if ENV["CI"]
Whisper.instance_variable_set "@whisper", nil
GC.start
@params = Parakeet::Params.new
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
end
def test_new_segment_callback
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_kind_of Integer, n_new
assert n_new > 0
assert_same @parakeet, context
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - n_new + i # new segments occupy the last n_new indices
start_time = context.full_get_segment_t0(i_segment) * 10
end_time = context.full_get_segment_t1(i_segment) * 10
text = context.full_get_segment_text(i_segment)
assert_kind_of Integer, start_time
assert start_time >= 0
assert_kind_of Integer, end_time
assert end_time > 0
assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0
end
}
@parakeet.transcribe AUDIO, @params
end
def test_on_new_segment
seg = nil
index = 0
@params.on_new_segment do |segment|
assert_instance_of Parakeet::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text)
end
index += 1
end
@parakeet.transcribe AUDIO, @params
assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your/, seg.text)
end
def test_on_new_token
index = 0
@params.on_new_token do |token|
assert_instance_of Parakeet::Token, token
if index == 0
assert_instance_of Integer, token.start_time
assert_match "▁And", token.text
end
index += 1
end
@parakeet.transcribe AUDIO, @params
end
def test_on_progress
first = nil
@params.on_progress do |progress|
assert_kind_of Integer, progress
assert 0 <= progress && progress <= 100
first = progress if first.nil?
end
@parakeet.transcribe AUDIO, @params
assert_equal 0, first
end
def test_on_encoder_begin
i = 0
@params.on_encoder_begin do
i += 1
end
@parakeet.transcribe AUDIO, @params
assert i > 0
end
def test_abort_on
do_abort = false
@params.on_new_segment do |segment|
do_abort = true if segment.text.match?(/ask/)
end
i = 0
@params.abort_on do
i += 1
do_abort
end
@parakeet.transcribe(AUDIO, @params) rescue nil
assert i > 0
end
end

@@ -0,0 +1,116 @@
require_relative "helper"
require "stringio"
class TestParakeetContext < TestBase
def setup
omit "Skip not to download large model" if ENV["CI"]
Whisper.instance_variable_set "@whisper", nil
GC.start
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
@params = Parakeet::Params.new
end
def test_new
assert_instance_of Parakeet::Context, @parakeet
end
def test_new_with_params
log_callback = Parakeet.instance_variable_get(:@log_callback)
user_data = Parakeet.instance_variable_get(:@log_callback_user_data)
begin
logs = ""
Parakeet.log_set proc {|level, message| logs << message}, nil
params = Parakeet::Context::Params.new(use_gpu: false)
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params)
assert_instance_of Parakeet::Context, parakeet
assert_match(/use gpu\s+=\s+0/, logs)
ensure
Parakeet.log_set log_callback, user_data
end
end
sub_test_case "full" do
def setup
super
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
end
def test_full
@parakeet.full @params, @samples, @samples.length
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match(/ask not what your country can do for you, ask what you can do for your/, segments.first.text)
end
def test_full_without_length
@parakeet.full(@params, @samples)
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match(/ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text)
end
def test_full_enumerator
samples = @samples.each
@parakeet.full @params, samples, @samples.length
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match(/ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text)
end
def test_full_enumerator_without_length
samples = @samples.each
assert_raise ArgumentError do
@parakeet.full @params, samples
end
end
def test_full_enumerator_with_too_large_length
samples = @samples.each.take(10).to_enum
assert_raise StopIteration do
@parakeet.full @params, samples, 11
end
end
def test_full_with_memory_view
samples = JFKReader.new(AUDIO)
@parakeet.full @params, samples
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match(/ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text)
end
def test_full_with_memory_view_gc
samples = JFKReader.new(AUDIO)
@parakeet.full(@params, samples)
GC.start
require "fiddle"
Fiddle::MemoryView.export samples do |view|
assert_equal 176000, view.to_s.unpack("#{view.format}*").length
end
end
end
def test_transcribe
assert_nothing_raised do
@parakeet.transcribe AUDIO, @params
end
end
def test_transcribe_with_pathname
assert_nothing_raised do
@parakeet.transcribe Pathname(AUDIO), @params
end
end
def test_transcribe_with_nothing
assert_raise_message(/open/) do
@parakeet.transcribe "nothing", @params
end
end
end

@@ -0,0 +1,24 @@
require_relative "helper"
class TestParakeetContextParams < TestBase
def setup
@params = Parakeet::Context::Params.new
end
def test_new
assert_instance_of Parakeet::Context::Params, @params
end
def test_attributes
assert_true @params.use_gpu
assert_instance_of Integer, @params.gpu_device
end
def test_attribute_writer
@params.use_gpu = false
assert_false @params.use_gpu
@params.gpu_device = 2
assert_equal 2, @params.gpu_device
end
end

@@ -0,0 +1,21 @@
require_relative "helper"
class TestParakeetModel < TestBase
def test_model
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
assert_instance_of Parakeet::Model, parakeet.model
end
def test_attributes
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
model = parakeet.model
assert_equal 10, model.n_vocab
assert_equal 3200, model.n_audio_ctx
assert_equal 8, model.n_audio_state
assert_equal 2, model.n_audio_head
assert_equal 1, model.n_audio_layer
assert_equal 16, model.n_mels
assert_equal 0, model.ftype
end
end

@@ -0,0 +1,78 @@
require_relative "helper"
require "etc"
class TestParakeetParams < TestBase
PARAM_NAMES = [
:n_threads,
:offset_ms,
:duration_ms,
:no_context,
:audio_ctx
]
def setup
@params = Parakeet::Params.new
end
def test_new
assert_instance_of Parakeet::Params, @params
end
def test_n_threads
assert_equal [4, Etc.nprocessors].min, @params.n_threads
@params.n_threads = 1
assert_equal 1, @params.n_threads
end
def test_offset_ms
assert_equal 0, @params.offset_ms
@params.offset_ms = 10_000
assert_equal 10_000, @params.offset_ms
end
def test_duration_ms
assert_equal 0, @params.duration_ms
@params.duration_ms = 60_000
assert_equal 60_000, @params.duration_ms
end
def test_no_context
assert_equal true, @params.no_context
@params.no_context = false
assert_equal false, @params.no_context
end
def test_audio_ctx
assert_equal 0, @params.audio_ctx
@params.audio_ctx = 1
assert_equal 1, @params.audio_ctx
end
def test_new_with_kw_args
params = Parakeet::Params.new(n_threads: 1)
assert_equal 1, params.n_threads
assert_equal 0, params.offset_ms
end
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
def test_new_with_kw_args_default_values(param)
default_value = @params.send(param)
value = case [param, default_value]
in [*, true | false]
!default_value
in [*, Integer]
default_value + 1
end
params = Parakeet::Params.new(param => value)
assert_equal value, params.send(param)
PARAM_NAMES.reject {|name| name == param}.each do |name|
assert_equal @params.send(name), params.send(name)
end
end
end

@@ -0,0 +1,42 @@
require_relative "helper"
class TestParakeetSegment < TestBase
def setup
omit "Skip not to download large model" if ENV["CI"]
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
@parakeet.transcribe AUDIO, Parakeet::Params.new
end
def test_segment
whole_text = ""
@parakeet.each_segment do |segment|
assert_instance_of Parakeet::Segment, segment
assert_kind_of Integer, segment.start_time
assert segment.end_time >= segment.start_time
assert_kind_of String, segment.text
whole_text << segment.text
end
assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text)
end
def test_deconstruct_keys
segment = @parakeet.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text
}
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text])
end
def test_deconstruct_keys_with_nil
segment = @parakeet.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text
}
assert_equal expected, segment.deconstruct_keys(nil)
end
end

@@ -0,0 +1,73 @@
require_relative "helper"
class TestParakeetToken < TestBase
ATTRS = %i[
id
duration_idx
duration_value
frame_index
probability
log_probability
start_time
end_time
word_start?
text
]
def setup
omit "Skip not to download large model" if ENV["CI"]
Whisper.instance_variable_set "@whisper", nil
GC.start
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
params = Parakeet::Params.new
parakeet.transcribe AUDIO, params
@segment = parakeet.each_segment.first
end
def test_each_token
i = 0
@segment.each_token do |token|
i += 1
assert_instance_of Parakeet::Token, token
end
assert_equal 38, i
end
def test_each_token_without_block
assert_instance_of Enumerator, @segment.each_token
end
def test_token
token = @segment.each_token.first
assert_instance_of Parakeet::Token, token
assert_instance_of Integer, token.id
assert_instance_of Integer, token.duration_idx
assert_instance_of Integer, token.duration_value
assert_instance_of Integer, token.frame_index
assert_instance_of Float, token.probability
assert_instance_of Float, token.log_probability
assert_instance_of Integer, token.start_time
assert_instance_of Integer, token.end_time
assert_instance_of String, token.text
end
def test_text
assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."],
@segment.each_token.collect(&:text)
end
def test_deconstruct_keys_with_nil
token = @segment.each_token.first
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
assert_equal expected, token.deconstruct_keys(nil)
end
def test_deconstruct_keys_with_keys
token = @segment.each_token.first
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
assert_equal expected, token.deconstruct_keys(expected.keys)
end
end

@@ -9,7 +9,7 @@ class TestVADSegment < TestBase
end
assert_raise do
- segments.end_time
+ segment.end_time
end
assert_raise do

@@ -149,6 +149,7 @@ class TestWhisper < TestBase
}
Whisper.log_set log_callback, user_data
Whisper::Context.new("base.en")
sleep 0.1 # wait for logs to be dequeued
assert logs.length > 30
logs.each do |log|

@@ -23,7 +23,7 @@ Gem::Specification.new do |s|
s.test_files = s.files.select {|file| file.start_with? "test/"}
s.extensions << 'ext/extconf.rb'
- s.required_ruby_version = '>= 3.1.0'
+ s.required_ruby_version = '>= 3.3.0'
#### Documentation and testing.
s.homepage = 'https://github.com/ggml-org/whisper.cpp'