ruby : add `Whisper::Context::Params`, fix token memory management (#3647)

* Don't convert to temporary VALUE

* Define Whisper::Context::Params

* Add test for Whisper::Context::Params

* Implement Whisper::Context::Params

* Add tests for Context::Params

* Fix Whisper::Token memory management

* Add test for token_timestamps

* Make Context accept Context::Params

* Make Context::Params.new accept keyword args

* Add test for Context::Params.new with keyword args

* Add signature of Context::Params

* Add example for Whisper::Token

* Fix typos

* Revert "Don't convert to temporary VALUE"

This reverts commit dee66e7384.

* Hold Token#text as Ruby objectd

* Don't use pointer for ruby_whisper_context_params.params

* Use RUBY_DEFAULT_FREE instead of custom function

* Update bindings/ruby/README.md

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>

* Add document for Whisper::Context::Params

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
This commit is contained in:
KITAITI Makoto 2026-02-04 20:33:09 +09:00 committed by GitHub
parent aa1bc0d1a6
commit 941bdabbe4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 435 additions and 15 deletions

View File

@ -247,6 +247,58 @@ whisper.transcribe("path/to/audio.wav", params)
```
### Tokens ###
Each segment has tokens.
To enable token timestamps, you need to set `Whisper::Params#token_timestamps = true`. Then, retrieve tokens from segments using `Whisper::Segment#each_token`.
```ruby
whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new(token_timestamps: true)
whisper
.transcribe("path/to/audio.wav", params)
.each_segment do |segment|
segment.each_token do |token|
token => {start_time:, end_time:, text:, probability:}
st = "%05.2fs" % (start_time / 1000.0)
et = "%05.2fs" % (end_time / 1000.0)
prob = "%.1f%%" % (probability * 100)
puts "[#{st} --> #{et}] #{text} (#{prob})"
end
end
```
```
[00.00s --> 00.00s] [_BEG_] (84.2%)
[00.32s --> 00.37s] And (71.2%)
[00.37s --> 00.53s] so (98.5%)
[00.69s --> 00.85s] my (70.7%)
[00.85s --> 01.59s] fellow (99.5%)
[01.59s --> 02.10s] Americans (90.1%)
[02.85s --> 03.30s] , (28.4%)
[03.30s --> 04.14s] ask (79.8%)
[04.14s --> 04.28s] not (78.9%)
[05.03s --> 05.35s] what (93.3%)
[05.41s --> 05.74s] your (98.8%)
[05.74s --> 06.41s] country (99.6%)
[06.41s --> 06.74s] can (97.7%)
[06.74s --> 06.92s] do (99.0%)
[07.00s --> 07.00s] for (95.8%)
[07.01s --> 07.52s] you (98.5%)
[07.81s --> 08.05s] , (49.3%)
[08.19s --> 08.37s] ask (65.6%)
[08.37s --> 08.75s] what (98.8%)
[08.91s --> 09.04s] you (98.2%)
[09.04s --> 09.32s] can (96.9%)
[09.32s --> 09.38s] do (90.3%)
[09.44s --> 09.76s] for (91.8%)
[09.76s --> 09.99s] your (98.2%)
[10.02s --> 10.36s] country (99.6%)
[10.51s --> 10.99s] . (87.0%)
[11.00s --> 11.00s] [_TT_550] (7.6%)
```
### Models ###
You can see model information:
@ -342,6 +394,20 @@ whisper
.full(Whisper::Params.new, samples)
```
Custom context params
---------------------
You can use customize `Whisper::Context`'s behavior using `Whisper::Context::Params`.
```ruby
context_params = Whisper::Context::Params.new(
use_gpu: false,
flash_attn: false,
# etc
)
whisper = Whisper::Context.new("base", context_params)
```
Using VAD separately from ASR
-----------------------------

View File

@ -33,7 +33,8 @@ static bool is_log_callback_finalized = false;
// High level API
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
extern void init_ruby_whisper_context(VALUE *mWhisper);
extern VALUE init_ruby_whisper_context(VALUE *mWhisper);
extern void init_ruby_whisper_context_params(VALUE *cContext);
extern void init_ruby_whisper_params(VALUE *mWhisper);
extern void init_ruby_whisper_error(VALUE *mWhisper);
extern void init_ruby_whisper_segment(VALUE *mWhisper);
@ -162,6 +163,22 @@ void Init_whisper() {
rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
rb_define_const(mWhisper, "AHEADS_NONE", INT2NUM(WHISPER_AHEADS_NONE));
rb_define_const(mWhisper, "AHEADS_N_TOP_MOST", INT2NUM(WHISPER_AHEADS_N_TOP_MOST));
rb_define_const(mWhisper, "AHEADS_CUSTOM", INT2NUM(WHISPER_AHEADS_CUSTOM));
rb_define_const(mWhisper, "AHEADS_TINY_EN", INT2NUM(WHISPER_AHEADS_TINY_EN));
rb_define_const(mWhisper, "AHEADS_TINY", INT2NUM(WHISPER_AHEADS_TINY));
rb_define_const(mWhisper, "AHEADS_BASE_EN", INT2NUM(WHISPER_AHEADS_BASE_EN));
rb_define_const(mWhisper, "AHEADS_BASE", INT2NUM(WHISPER_AHEADS_BASE));
rb_define_const(mWhisper, "AHEADS_SMALL_EN", INT2NUM(WHISPER_AHEADS_SMALL_EN));
rb_define_const(mWhisper, "AHEADS_SMALL", INT2NUM(WHISPER_AHEADS_SMALL));
rb_define_const(mWhisper, "AHEADS_MEDIUM_EN", INT2NUM(WHISPER_AHEADS_MEDIUM_EN));
rb_define_const(mWhisper, "AHEADS_MEDIUM", INT2NUM(WHISPER_AHEADS_MEDIUM));
rb_define_const(mWhisper, "AHEADS_LARGE_V1", INT2NUM(WHISPER_AHEADS_LARGE_V1));
rb_define_const(mWhisper, "AHEADS_LARGE_V2", INT2NUM(WHISPER_AHEADS_LARGE_V2));
rb_define_const(mWhisper, "AHEADS_LARGE_V3", INT2NUM(WHISPER_AHEADS_LARGE_V3));
rb_define_const(mWhisper, "AHEADS_LARGE_V3_TURBO", INT2NUM(WHISPER_AHEADS_LARGE_V3_TURBO));
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
@ -170,7 +187,8 @@ void Init_whisper() {
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
init_ruby_whisper_context(&mWhisper);
cContext = init_ruby_whisper_context(&mWhisper);
init_ruby_whisper_context_params(&cContext);
init_ruby_whisper_params(&mWhisper);
init_ruby_whisper_error(&mWhisper);
init_ruby_whisper_segment(&mWhisper);

View File

@ -16,6 +16,10 @@ typedef struct {
struct whisper_context *context;
} ruby_whisper;
typedef struct ruby_whisper_context_params {
struct whisper_context_params params;
} ruby_whisper_context_params;
typedef struct {
struct whisper_full_params params;
bool diarize;
@ -37,7 +41,7 @@ typedef struct {
typedef struct {
whisper_token_data *token_data;
const char *text;
VALUE text;
} ruby_whisper_token;
typedef struct {
@ -71,7 +75,11 @@ typedef struct parsed_samples_t {
} \
} while (0)
#define GetToken(obj, rwt) do { \
#define GetContextParams(obj, rwcp) do { \
TypedData_Get_Struct((obj), ruby_whisper_context_params, &ruby_whisper_context_params_type, (rwcp)); \
} while (0)
#define GetToken(obj, rwt) do { \
TypedData_Get_Struct((obj), ruby_whisper_token, &ruby_whisper_token_type, (rwt)); \
if ((rwt)->token_data == NULL) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \

View File

@ -18,6 +18,7 @@ extern VALUE eError;
extern VALUE cModel;
extern const rb_data_type_t ruby_whisper_params_type;
extern const rb_data_type_t ruby_whisper_context_params_type;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
@ -143,16 +144,25 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
{
ruby_whisper *rw;
VALUE whisper_model_file_path;
VALUE context_params;
struct whisper_context_params params;
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
rb_scan_args(argc, argv, "11", &whisper_model_file_path, &context_params);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
if (NIL_P(context_params)) {
params = whisper_context_default_params();
} else {
ruby_whisper_context_params *rwcp;
GetContextParams(context_params, rwcp);
params = rwcp->params;
}
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), params);
if (rw->context == NULL) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
@ -711,7 +721,7 @@ ruby_whisper_get_model(VALUE self)
return rb_whisper_model_s_new(self);
}
void
VALUE
init_ruby_whisper_context(VALUE *mWhisper)
{
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
@ -749,4 +759,6 @@ init_ruby_whisper_context(VALUE *mWhisper)
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
return cContext;
}

View File

@ -0,0 +1,163 @@
#include "ruby_whisper.h"
#define NUM_PARAMS 6
#define DEF_BOOLEAN_ATTR_METHOD(name) \
static VALUE \
ruby_whisper_context_params_get_ ## name(VALUE self) { \
ruby_whisper_context_params *rwcp; \
GetContextParams(self, rwcp); \
return rwcp->params.name ? Qtrue : Qfalse; \
} \
static VALUE \
ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \
ruby_whisper_context_params *rwcp; \
GetContextParams(self, rwcp); \
rwcp->params.name = RTEST(value); \
return value; \
}
#define DEF_INT_ATTR_METHOD(name) \
static VALUE \
ruby_whisper_context_params_get_ ## name(VALUE self) { \
ruby_whisper_context_params *rwcp; \
GetContextParams(self, rwcp); \
return INT2NUM(rwcp->params.name); \
} \
static VALUE \
ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \
ruby_whisper_context_params *rwcp; \
GetContextParams(self, rwcp); \
rwcp->params.name = NUM2INT(value); \
return value; \
}
#define DEFINE_PARAM(param_name, nth) \
id_ ## param_name = rb_intern(#param_name); \
param_names[nth] = id_ ## param_name; \
rb_define_method(cContextParams, #param_name, ruby_whisper_context_params_get_ ## param_name, 0); \
rb_define_method(cContextParams, #param_name "=", ruby_whisper_context_params_set_ ## param_name, 1);
VALUE cContextParams;
static ID param_names[NUM_PARAMS];
static ID id_use_gpu;
static ID id_flash_attn;
static ID id_gpu_device;
static ID id_dtw_token_timestamps;
static ID id_dtw_aheads_preset;
static ID id_dtw_n_top;
static size_t
ruby_whisper_context_params_memsize(const void *p)
{
const ruby_whisper_context_params *rwcp = (ruby_whisper_context_params *)p;
if (!rwcp) {
return 0;
}
return sizeof(ruby_whisper_context_params);
}
const rb_data_type_t ruby_whisper_context_params_type = {
"ruby_whisper_context_params",
{0, RUBY_DEFAULT_FREE, ruby_whisper_context_params_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_context_params_s_allocate(VALUE klass)
{
ruby_whisper_context_params *rwcp;
return TypedData_Make_Struct(klass, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp);
}
DEF_BOOLEAN_ATTR_METHOD(use_gpu);
DEF_BOOLEAN_ATTR_METHOD(flash_attn);
DEF_INT_ATTR_METHOD(gpu_device);
DEF_BOOLEAN_ATTR_METHOD(dtw_token_timestamps);
DEF_INT_ATTR_METHOD(dtw_aheads_preset);
static VALUE
ruby_whisper_context_params_get_dtw_n_top(VALUE self) {
ruby_whisper_context_params *rwcp;
GetContextParams(self, rwcp);
int dtw_n_top = rwcp->params.dtw_n_top;
return dtw_n_top == -1 ? Qnil : INT2NUM(dtw_n_top);
}
static VALUE
ruby_whisper_context_params_set_dtw_n_top(VALUE self, VALUE value) {
ruby_whisper_context_params *rwcp;
GetContextParams(self, rwcp);
rwcp->params.dtw_n_top = NIL_P(value) ? -1 : NUM2INT(value);
return value;
}
#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_context_params_set_ ## param_name(self, value); \
continue; \
}
static VALUE
ruby_whisper_context_params_initialize(int argc, VALUE *argv, VALUE self)
{
ruby_whisper_context_params *rwcp;
TypedData_Get_Struct(self, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp);
rwcp->params = whisper_context_default_params();
VALUE kw_hash;
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return Qnil;
}
VALUE values[NUM_PARAMS] = {Qundef};
rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
ID id;
VALUE value;
for (int i = 0; i < NUM_PARAMS; i++) {
id = param_names[i];
value = values[i];
if (value == Qundef) {
continue;
}
SET_PARAM_IF_SAME(use_gpu)
SET_PARAM_IF_SAME(flash_attn)
SET_PARAM_IF_SAME(gpu_device)
SET_PARAM_IF_SAME(dtw_token_timestamps)
SET_PARAM_IF_SAME(dtw_aheads_preset)
SET_PARAM_IF_SAME(dtw_n_top)
}
return Qnil;
}
#undef SET_PARAM_IF_SAME
void
init_ruby_whisper_context_params(VALUE *cContext)
{
cContextParams = rb_define_class_under(*cContext, "Params", rb_cObject);
rb_define_alloc_func(cContextParams, ruby_whisper_context_params_s_allocate);
rb_define_method(cContextParams, "initialize", ruby_whisper_context_params_initialize, -1);
DEFINE_PARAM(use_gpu, 0)
DEFINE_PARAM(flash_attn, 1)
DEFINE_PARAM(gpu_device, 2)
DEFINE_PARAM(dtw_token_timestamps, 3)
DEFINE_PARAM(dtw_aheads_preset, 4)
DEFINE_PARAM(dtw_n_top, 5)
}
#undef DEFINE_PARAM
#undef DEF_INT_ATTR_METHOD
#undef DEF_BOOLEAN_ATTR_METHOD
#undef NUM_PARAMS

View File

@ -24,12 +24,34 @@ ruby_whisper_token_memsize(const void *p)
if (!rwt) {
return 0;
}
return sizeof(rwt);
size_t size = sizeof(*rwt);
if (rwt->token_data) {
size += sizeof(*rwt->token_data);
}
return size;
}
static void
ruby_whisper_token_mark(void *p)
{
ruby_whisper_token *rwt = (ruby_whisper_token *)p;
rb_gc_mark(rwt->text);
}
static void
ruby_whisper_token_free(void *p)
{
ruby_whisper_token *rwt = (ruby_whisper_token *)p;
if (rwt->token_data) {
xfree(rwt->token_data);
rwt->token_data = NULL;
}
xfree(rwt);
}
static const rb_data_type_t ruby_whisper_token_type = {
"ruby_whisper_token",
{0, RUBY_DEFAULT_FREE, ruby_whisper_token_memsize,},
{ruby_whisper_token_mark, ruby_whisper_token_free, ruby_whisper_token_memsize,},
0, 0,
0
};
@ -40,19 +62,19 @@ ruby_whisper_token_allocate(VALUE klass)
ruby_whisper_token *rwt;
VALUE token = TypedData_Make_Struct(klass, ruby_whisper_token, &ruby_whisper_token_type, rwt);
rwt->token_data = NULL;
rwt->text = NULL;
rwt->text = Qnil;
return token;
}
VALUE
ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int i_token)
{
whisper_token_data token_data = whisper_full_get_token_data(context, i_segment, i_token);
const VALUE token = ruby_whisper_token_allocate(cToken);
ruby_whisper_token *rwt;
TypedData_Get_Struct(token, ruby_whisper_token, &ruby_whisper_token_type, rwt);
rwt->token_data = &token_data;
rwt->text = whisper_full_get_token_text(context, i_segment, i_token);
rwt->token_data = ALLOC(whisper_token_data);
*(rwt->token_data) = whisper_full_get_token_data(context, i_segment, i_token);
rwt->text = rb_str_new2(whisper_full_get_token_text(context, i_segment, i_token));
return token;
}
@ -182,10 +204,9 @@ ruby_whisper_token_get_text(VALUE self)
{
ruby_whisper_token *rwt;
GetToken(self, rwt);
return rb_str_new2(rwt->text);
return rwt->text;
}
/*
* Start time of the token.
*

View File

@ -17,6 +17,21 @@ module Whisper
LOG_LEVEL_ERROR: Integer
LOG_LEVEL_DEBUG: Integer
LOG_LEVEL_CONT: Integer
AHEADS_NONE: Integer
AHEADS_N_TOP_MOST: Integer
AHEADS_CUSTOM: Integer
AHEADS_TINY_EN: Integer
AHEADS_TINY: Integer
AHEADS_BASE_EN: Integer
AHEADS_BASE: Integer
AHEADS_SMALL_EN: Integer
AHEADS_SMALL: Integer
AHEADS_MEDIUM_EN: Integer
AHEADS_MEDIUM: Integer
AHEADS_LARGE_V1: Integer
AHEADS_LARGE_V2: Integer
AHEADS_LARGE_V3: Integer
AHEADS_LARGE_V3_TURBO: Integer
def self.lang_max_id: () -> Integer
def self.lang_id: (string name) -> Integer
@ -120,6 +135,30 @@ module Whisper
def to_srt: () -> String
def to_webvtt: () -> String
class Params
def self.new: (
use_gpu: boolish,
flash_attn: boolish,
gpu_device: Integer,
dtw_token_timestamps: boolish,
dtw_aheads_preset: Integer,
dtw_n_top: Integer | nil,
) -> instance
def use_gpu=: (boolish) -> boolish
def use_gpu: () -> (true | false)
def flash_attn=: (boolish) -> boolish
def flash_attn: () -> (true | false)
def gpu_device=: (Integer) -> Integer
def gpu_device: () -> Integer
def dtw_token_timestamps=: (boolish) -> boolish
def dtw_token_timestamps: () -> (true | false)
def dtw_aheads_preset=: (Integer) -> Integer
def dtw_aheads_preset: () -> Integer
def dtw_n_top=: (Integer | nil) -> (Integer | nil)
def dtw_n_top: () -> (Integer | nil)
end
end
class Params

View File

@ -0,0 +1,82 @@
require_relative "helper"
class TestContextParams < TestBase
PARAM_NAMES = [
:use_gpu,
:flash_attn,
:gpu_device,
:dtw_token_timestamps,
:dtw_aheads_preset,
:dtw_n_top
]
def test_new
params = Whisper::Context::Params.new
assert_instance_of Whisper::Context::Params, params
end
def test_attributes
params = Whisper::Context::Params.new
assert_true params.use_gpu
params.use_gpu = false
assert_false params.use_gpu
assert_true params.flash_attn
params.flash_attn = false
assert_false params.flash_attn
assert_equal 0, params.gpu_device
params.gpu_device = 1
assert_equal 1, params.gpu_device
assert_false params.dtw_token_timestamps
params.dtw_token_timestamps = true
assert_true params.dtw_token_timestamps
assert_equal Whisper::AHEADS_NONE, params.dtw_aheads_preset
params.dtw_aheads_preset =Whisper::AHEADS_BASE
assert_equal Whisper::AHEADS_BASE, params.dtw_aheads_preset
assert_nil params.dtw_n_top
params.dtw_n_top = 6
assert_equal 6, params.dtw_n_top
params.dtw_n_top = nil
assert_nil params.dtw_n_top
end
def test_new_with_kw_args
params = Whisper::Context::Params.new(use_gpu: false)
assert_false params.use_gpu
end
def test_new_with_kw_wargs_non_existent
assert_raise ArgumentError do
Whisper::Context::Params.new(non_existent: "value")
end
end
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
def test_new_with_kw_args_default_values(param)
default_params = Whisper::Context::Params.new
default_value = default_params.send(param)
value = if param == :dtw_n_top
6
else
case default_value
in true | false
!default_value
in Integer
default_value + 1
end
end
params = Whisper::Context::Params.new(param => value)
assert_equal value, params.send(param)
PARAM_NAMES.reject {|name| name == param}.each do |name|
expected = default_params.send(name)
actual = params.send(name)
assert_equal expected, actual
end
end
end

View File

@ -56,6 +56,17 @@ class TestToken < TestBase
@segment.each_token.collect(&:text)
end
def test_token_timestamps
params = Whisper::Params.new(token_timestamps: true)
whisper.transcribe(TestBase::AUDIO, params)
prev = -1
whisper.each_segment.first.each_token do |token|
assert token.start_time >= prev
assert token.end_time >= token.start_time
prev = token.end_time
end
end
def test_deconstruct_keys_with_nil
keys = %i[id tid probability log_probability pt ptsum t_dtw voice_length start_time end_time text]
expected = keys.collect {|key| [key, @token.send(key)] }.to_h