ruby : add `Whisper::Context::Params`, fix token memory management (#3647)
* Don't convert to temporary VALUE
* Define Whisper::Context::Params
* Add test for Whisper::Context::Params
* Implement Whisper::Context::Params
* Add tests for Context::Params
* Fix Whisper::Token memory management
* Add test for token_timestamps
* Make Context accept Context::Params
* Make Context::Params.new accept keyword args
* Add test for Context::Params.new with keyword args
* Add signature of Context::Params
* Add example for Whisper::Token
* Fix typos
* Revert "Don't convert to temporary VALUE"
This reverts commit dee66e7384.
* Hold Token#text as Ruby objectd
* Don't use pointer for ruby_whisper_context_params.params
* Use RUBY_DEFAULT_FREE instead of custom function
* Update bindings/ruby/README.md
Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
* Add document for Whisper::Context::Params
---------
Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
This commit is contained in:
parent
aa1bc0d1a6
commit
941bdabbe4
|
|
@ -247,6 +247,58 @@ whisper.transcribe("path/to/audio.wav", params)
|
|||
|
||||
```
|
||||
|
||||
### Tokens ###
|
||||
|
||||
Each segment has tokens.
|
||||
|
||||
To enable token timestamps, you need to set `Whisper::Params#token_timestamps = true`. Then, retrieve tokens from segments using `Whisper::Segment#each_token`.
|
||||
|
||||
```ruby
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new(token_timestamps: true)
|
||||
whisper
|
||||
.transcribe("path/to/audio.wav", params)
|
||||
.each_segment do |segment|
|
||||
segment.each_token do |token|
|
||||
token => {start_time:, end_time:, text:, probability:}
|
||||
st = "%05.2fs" % (start_time / 1000.0)
|
||||
et = "%05.2fs" % (end_time / 1000.0)
|
||||
prob = "%.1f%%" % (probability * 100)
|
||||
puts "[#{st} --> #{et}] #{text} (#{prob})"
|
||||
end
|
||||
end
|
||||
```
|
||||
|
||||
```
|
||||
[00.00s --> 00.00s] [_BEG_] (84.2%)
|
||||
[00.32s --> 00.37s] And (71.2%)
|
||||
[00.37s --> 00.53s] so (98.5%)
|
||||
[00.69s --> 00.85s] my (70.7%)
|
||||
[00.85s --> 01.59s] fellow (99.5%)
|
||||
[01.59s --> 02.10s] Americans (90.1%)
|
||||
[02.85s --> 03.30s] , (28.4%)
|
||||
[03.30s --> 04.14s] ask (79.8%)
|
||||
[04.14s --> 04.28s] not (78.9%)
|
||||
[05.03s --> 05.35s] what (93.3%)
|
||||
[05.41s --> 05.74s] your (98.8%)
|
||||
[05.74s --> 06.41s] country (99.6%)
|
||||
[06.41s --> 06.74s] can (97.7%)
|
||||
[06.74s --> 06.92s] do (99.0%)
|
||||
[07.00s --> 07.00s] for (95.8%)
|
||||
[07.01s --> 07.52s] you (98.5%)
|
||||
[07.81s --> 08.05s] , (49.3%)
|
||||
[08.19s --> 08.37s] ask (65.6%)
|
||||
[08.37s --> 08.75s] what (98.8%)
|
||||
[08.91s --> 09.04s] you (98.2%)
|
||||
[09.04s --> 09.32s] can (96.9%)
|
||||
[09.32s --> 09.38s] do (90.3%)
|
||||
[09.44s --> 09.76s] for (91.8%)
|
||||
[09.76s --> 09.99s] your (98.2%)
|
||||
[10.02s --> 10.36s] country (99.6%)
|
||||
[10.51s --> 10.99s] . (87.0%)
|
||||
[11.00s --> 11.00s] [_TT_550] (7.6%)
|
||||
```
|
||||
|
||||
### Models ###
|
||||
|
||||
You can see model information:
|
||||
|
|
@ -342,6 +394,20 @@ whisper
|
|||
.full(Whisper::Params.new, samples)
|
||||
```
|
||||
|
||||
Custom context params
|
||||
---------------------
|
||||
|
||||
You can use customize `Whisper::Context`'s behavior using `Whisper::Context::Params`.
|
||||
|
||||
```ruby
|
||||
context_params = Whisper::Context::Params.new(
|
||||
use_gpu: false,
|
||||
flash_attn: false,
|
||||
# etc
|
||||
)
|
||||
whisper = Whisper::Context.new("base", context_params)
|
||||
```
|
||||
|
||||
Using VAD separately from ASR
|
||||
-----------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,8 @@ static bool is_log_callback_finalized = false;
|
|||
// High level API
|
||||
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
|
||||
|
||||
extern void init_ruby_whisper_context(VALUE *mWhisper);
|
||||
extern VALUE init_ruby_whisper_context(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_context_params(VALUE *cContext);
|
||||
extern void init_ruby_whisper_params(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_error(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_segment(VALUE *mWhisper);
|
||||
|
|
@ -162,6 +163,22 @@ void Init_whisper() {
|
|||
rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
|
||||
|
||||
rb_define_const(mWhisper, "AHEADS_NONE", INT2NUM(WHISPER_AHEADS_NONE));
|
||||
rb_define_const(mWhisper, "AHEADS_N_TOP_MOST", INT2NUM(WHISPER_AHEADS_N_TOP_MOST));
|
||||
rb_define_const(mWhisper, "AHEADS_CUSTOM", INT2NUM(WHISPER_AHEADS_CUSTOM));
|
||||
rb_define_const(mWhisper, "AHEADS_TINY_EN", INT2NUM(WHISPER_AHEADS_TINY_EN));
|
||||
rb_define_const(mWhisper, "AHEADS_TINY", INT2NUM(WHISPER_AHEADS_TINY));
|
||||
rb_define_const(mWhisper, "AHEADS_BASE_EN", INT2NUM(WHISPER_AHEADS_BASE_EN));
|
||||
rb_define_const(mWhisper, "AHEADS_BASE", INT2NUM(WHISPER_AHEADS_BASE));
|
||||
rb_define_const(mWhisper, "AHEADS_SMALL_EN", INT2NUM(WHISPER_AHEADS_SMALL_EN));
|
||||
rb_define_const(mWhisper, "AHEADS_SMALL", INT2NUM(WHISPER_AHEADS_SMALL));
|
||||
rb_define_const(mWhisper, "AHEADS_MEDIUM_EN", INT2NUM(WHISPER_AHEADS_MEDIUM_EN));
|
||||
rb_define_const(mWhisper, "AHEADS_MEDIUM", INT2NUM(WHISPER_AHEADS_MEDIUM));
|
||||
rb_define_const(mWhisper, "AHEADS_LARGE_V1", INT2NUM(WHISPER_AHEADS_LARGE_V1));
|
||||
rb_define_const(mWhisper, "AHEADS_LARGE_V2", INT2NUM(WHISPER_AHEADS_LARGE_V2));
|
||||
rb_define_const(mWhisper, "AHEADS_LARGE_V3", INT2NUM(WHISPER_AHEADS_LARGE_V3));
|
||||
rb_define_const(mWhisper, "AHEADS_LARGE_V3_TURBO", INT2NUM(WHISPER_AHEADS_LARGE_V3_TURBO));
|
||||
|
||||
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
|
||||
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
|
||||
|
|
@ -170,7 +187,8 @@ void Init_whisper() {
|
|||
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
|
||||
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
|
||||
|
||||
init_ruby_whisper_context(&mWhisper);
|
||||
cContext = init_ruby_whisper_context(&mWhisper);
|
||||
init_ruby_whisper_context_params(&cContext);
|
||||
init_ruby_whisper_params(&mWhisper);
|
||||
init_ruby_whisper_error(&mWhisper);
|
||||
init_ruby_whisper_segment(&mWhisper);
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ typedef struct {
|
|||
struct whisper_context *context;
|
||||
} ruby_whisper;
|
||||
|
||||
typedef struct ruby_whisper_context_params {
|
||||
struct whisper_context_params params;
|
||||
} ruby_whisper_context_params;
|
||||
|
||||
typedef struct {
|
||||
struct whisper_full_params params;
|
||||
bool diarize;
|
||||
|
|
@ -37,7 +41,7 @@ typedef struct {
|
|||
|
||||
typedef struct {
|
||||
whisper_token_data *token_data;
|
||||
const char *text;
|
||||
VALUE text;
|
||||
} ruby_whisper_token;
|
||||
|
||||
typedef struct {
|
||||
|
|
@ -71,7 +75,11 @@ typedef struct parsed_samples_t {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetToken(obj, rwt) do { \
|
||||
#define GetContextParams(obj, rwcp) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_context_params, &ruby_whisper_context_params_type, (rwcp)); \
|
||||
} while (0)
|
||||
|
||||
#define GetToken(obj, rwt) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_token, &ruby_whisper_token_type, (rwt)); \
|
||||
if ((rwt)->token_data == NULL) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ extern VALUE eError;
|
|||
extern VALUE cModel;
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_params_type;
|
||||
extern const rb_data_type_t ruby_whisper_context_params_type;
|
||||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_s_new(VALUE context);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
|
|
@ -143,16 +144,25 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
|
|||
{
|
||||
ruby_whisper *rw;
|
||||
VALUE whisper_model_file_path;
|
||||
VALUE context_params;
|
||||
struct whisper_context_params params;
|
||||
|
||||
// TODO: we can support init from buffer here too maybe another ruby object to expose
|
||||
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||
rb_scan_args(argc, argv, "11", &whisper_model_file_path, &context_params);
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
|
||||
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
|
||||
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
|
||||
}
|
||||
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
|
||||
if (NIL_P(context_params)) {
|
||||
params = whisper_context_default_params();
|
||||
} else {
|
||||
ruby_whisper_context_params *rwcp;
|
||||
GetContextParams(context_params, rwcp);
|
||||
params = rwcp->params;
|
||||
}
|
||||
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), params);
|
||||
if (rw->context == NULL) {
|
||||
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
|
||||
}
|
||||
|
|
@ -711,7 +721,7 @@ ruby_whisper_get_model(VALUE self)
|
|||
return rb_whisper_model_s_new(self);
|
||||
}
|
||||
|
||||
void
|
||||
VALUE
|
||||
init_ruby_whisper_context(VALUE *mWhisper)
|
||||
{
|
||||
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
|
||||
|
|
@ -749,4 +759,6 @@ init_ruby_whisper_context(VALUE *mWhisper)
|
|||
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
|
||||
|
||||
rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
|
||||
|
||||
return cContext;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define NUM_PARAMS 6
|
||||
|
||||
#define DEF_BOOLEAN_ATTR_METHOD(name) \
|
||||
static VALUE \
|
||||
ruby_whisper_context_params_get_ ## name(VALUE self) { \
|
||||
ruby_whisper_context_params *rwcp; \
|
||||
GetContextParams(self, rwcp); \
|
||||
return rwcp->params.name ? Qtrue : Qfalse; \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \
|
||||
ruby_whisper_context_params *rwcp; \
|
||||
GetContextParams(self, rwcp); \
|
||||
rwcp->params.name = RTEST(value); \
|
||||
return value; \
|
||||
}
|
||||
|
||||
#define DEF_INT_ATTR_METHOD(name) \
|
||||
static VALUE \
|
||||
ruby_whisper_context_params_get_ ## name(VALUE self) { \
|
||||
ruby_whisper_context_params *rwcp; \
|
||||
GetContextParams(self, rwcp); \
|
||||
return INT2NUM(rwcp->params.name); \
|
||||
} \
|
||||
static VALUE \
|
||||
ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \
|
||||
ruby_whisper_context_params *rwcp; \
|
||||
GetContextParams(self, rwcp); \
|
||||
rwcp->params.name = NUM2INT(value); \
|
||||
return value; \
|
||||
}
|
||||
|
||||
#define DEFINE_PARAM(param_name, nth) \
|
||||
id_ ## param_name = rb_intern(#param_name); \
|
||||
param_names[nth] = id_ ## param_name; \
|
||||
rb_define_method(cContextParams, #param_name, ruby_whisper_context_params_get_ ## param_name, 0); \
|
||||
rb_define_method(cContextParams, #param_name "=", ruby_whisper_context_params_set_ ## param_name, 1);
|
||||
|
||||
VALUE cContextParams;
|
||||
|
||||
static ID param_names[NUM_PARAMS];
|
||||
static ID id_use_gpu;
|
||||
static ID id_flash_attn;
|
||||
static ID id_gpu_device;
|
||||
static ID id_dtw_token_timestamps;
|
||||
static ID id_dtw_aheads_preset;
|
||||
static ID id_dtw_n_top;
|
||||
|
||||
static size_t
|
||||
ruby_whisper_context_params_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper_context_params *rwcp = (ruby_whisper_context_params *)p;
|
||||
if (!rwcp) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_context_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_context_params_type = {
|
||||
"ruby_whisper_context_params",
|
||||
{0, RUBY_DEFAULT_FREE, ruby_whisper_context_params_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_context_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_context_params *rwcp;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp);
|
||||
}
|
||||
|
||||
DEF_BOOLEAN_ATTR_METHOD(use_gpu);
|
||||
DEF_BOOLEAN_ATTR_METHOD(flash_attn);
|
||||
DEF_INT_ATTR_METHOD(gpu_device);
|
||||
DEF_BOOLEAN_ATTR_METHOD(dtw_token_timestamps);
|
||||
DEF_INT_ATTR_METHOD(dtw_aheads_preset);
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_context_params_get_dtw_n_top(VALUE self) {
|
||||
ruby_whisper_context_params *rwcp;
|
||||
GetContextParams(self, rwcp);
|
||||
|
||||
int dtw_n_top = rwcp->params.dtw_n_top;
|
||||
|
||||
return dtw_n_top == -1 ? Qnil : INT2NUM(dtw_n_top);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_context_params_set_dtw_n_top(VALUE self, VALUE value) {
|
||||
ruby_whisper_context_params *rwcp;
|
||||
GetContextParams(self, rwcp);
|
||||
|
||||
rwcp->params.dtw_n_top = NIL_P(value) ? -1 : NUM2INT(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
#define SET_PARAM_IF_SAME(param_name) \
|
||||
if (id == id_ ## param_name) { \
|
||||
ruby_whisper_context_params_set_ ## param_name(self, value); \
|
||||
continue; \
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_context_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
ruby_whisper_context_params *rwcp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp);
|
||||
rwcp->params = whisper_context_default_params();
|
||||
|
||||
VALUE kw_hash;
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
VALUE values[NUM_PARAMS] = {Qundef};
|
||||
rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
|
||||
|
||||
ID id;
|
||||
VALUE value;
|
||||
for (int i = 0; i < NUM_PARAMS; i++) {
|
||||
id = param_names[i];
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
SET_PARAM_IF_SAME(use_gpu)
|
||||
SET_PARAM_IF_SAME(flash_attn)
|
||||
SET_PARAM_IF_SAME(gpu_device)
|
||||
SET_PARAM_IF_SAME(dtw_token_timestamps)
|
||||
SET_PARAM_IF_SAME(dtw_aheads_preset)
|
||||
SET_PARAM_IF_SAME(dtw_n_top)
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
#undef SET_PARAM_IF_SAME
|
||||
|
||||
void
|
||||
init_ruby_whisper_context_params(VALUE *cContext)
|
||||
{
|
||||
cContextParams = rb_define_class_under(*cContext, "Params", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cContextParams, ruby_whisper_context_params_s_allocate);
|
||||
rb_define_method(cContextParams, "initialize", ruby_whisper_context_params_initialize, -1);
|
||||
|
||||
DEFINE_PARAM(use_gpu, 0)
|
||||
DEFINE_PARAM(flash_attn, 1)
|
||||
DEFINE_PARAM(gpu_device, 2)
|
||||
DEFINE_PARAM(dtw_token_timestamps, 3)
|
||||
DEFINE_PARAM(dtw_aheads_preset, 4)
|
||||
DEFINE_PARAM(dtw_n_top, 5)
|
||||
}
|
||||
|
||||
#undef DEFINE_PARAM
|
||||
#undef DEF_INT_ATTR_METHOD
|
||||
#undef DEF_BOOLEAN_ATTR_METHOD
|
||||
#undef NUM_PARAMS
|
||||
|
|
@ -24,12 +24,34 @@ ruby_whisper_token_memsize(const void *p)
|
|||
if (!rwt) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(rwt);
|
||||
size_t size = sizeof(*rwt);
|
||||
if (rwt->token_data) {
|
||||
size += sizeof(*rwt->token_data);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_token_mark(void *p)
|
||||
{
|
||||
ruby_whisper_token *rwt = (ruby_whisper_token *)p;
|
||||
rb_gc_mark(rwt->text);
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_token_free(void *p)
|
||||
{
|
||||
ruby_whisper_token *rwt = (ruby_whisper_token *)p;
|
||||
if (rwt->token_data) {
|
||||
xfree(rwt->token_data);
|
||||
rwt->token_data = NULL;
|
||||
}
|
||||
xfree(rwt);
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_token_type = {
|
||||
"ruby_whisper_token",
|
||||
{0, RUBY_DEFAULT_FREE, ruby_whisper_token_memsize,},
|
||||
{ruby_whisper_token_mark, ruby_whisper_token_free, ruby_whisper_token_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
|
@ -40,19 +62,19 @@ ruby_whisper_token_allocate(VALUE klass)
|
|||
ruby_whisper_token *rwt;
|
||||
VALUE token = TypedData_Make_Struct(klass, ruby_whisper_token, &ruby_whisper_token_type, rwt);
|
||||
rwt->token_data = NULL;
|
||||
rwt->text = NULL;
|
||||
rwt->text = Qnil;
|
||||
return token;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int i_token)
|
||||
{
|
||||
whisper_token_data token_data = whisper_full_get_token_data(context, i_segment, i_token);
|
||||
const VALUE token = ruby_whisper_token_allocate(cToken);
|
||||
ruby_whisper_token *rwt;
|
||||
TypedData_Get_Struct(token, ruby_whisper_token, &ruby_whisper_token_type, rwt);
|
||||
rwt->token_data = &token_data;
|
||||
rwt->text = whisper_full_get_token_text(context, i_segment, i_token);
|
||||
rwt->token_data = ALLOC(whisper_token_data);
|
||||
*(rwt->token_data) = whisper_full_get_token_data(context, i_segment, i_token);
|
||||
rwt->text = rb_str_new2(whisper_full_get_token_text(context, i_segment, i_token));
|
||||
return token;
|
||||
}
|
||||
|
||||
|
|
@ -182,10 +204,9 @@ ruby_whisper_token_get_text(VALUE self)
|
|||
{
|
||||
ruby_whisper_token *rwt;
|
||||
GetToken(self, rwt);
|
||||
return rb_str_new2(rwt->text);
|
||||
return rwt->text;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Start time of the token.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -17,6 +17,21 @@ module Whisper
|
|||
LOG_LEVEL_ERROR: Integer
|
||||
LOG_LEVEL_DEBUG: Integer
|
||||
LOG_LEVEL_CONT: Integer
|
||||
AHEADS_NONE: Integer
|
||||
AHEADS_N_TOP_MOST: Integer
|
||||
AHEADS_CUSTOM: Integer
|
||||
AHEADS_TINY_EN: Integer
|
||||
AHEADS_TINY: Integer
|
||||
AHEADS_BASE_EN: Integer
|
||||
AHEADS_BASE: Integer
|
||||
AHEADS_SMALL_EN: Integer
|
||||
AHEADS_SMALL: Integer
|
||||
AHEADS_MEDIUM_EN: Integer
|
||||
AHEADS_MEDIUM: Integer
|
||||
AHEADS_LARGE_V1: Integer
|
||||
AHEADS_LARGE_V2: Integer
|
||||
AHEADS_LARGE_V3: Integer
|
||||
AHEADS_LARGE_V3_TURBO: Integer
|
||||
|
||||
def self.lang_max_id: () -> Integer
|
||||
def self.lang_id: (string name) -> Integer
|
||||
|
|
@ -120,6 +135,30 @@ module Whisper
|
|||
|
||||
def to_srt: () -> String
|
||||
def to_webvtt: () -> String
|
||||
|
||||
class Params
|
||||
def self.new: (
|
||||
use_gpu: boolish,
|
||||
flash_attn: boolish,
|
||||
gpu_device: Integer,
|
||||
dtw_token_timestamps: boolish,
|
||||
dtw_aheads_preset: Integer,
|
||||
dtw_n_top: Integer | nil,
|
||||
) -> instance
|
||||
|
||||
def use_gpu=: (boolish) -> boolish
|
||||
def use_gpu: () -> (true | false)
|
||||
def flash_attn=: (boolish) -> boolish
|
||||
def flash_attn: () -> (true | false)
|
||||
def gpu_device=: (Integer) -> Integer
|
||||
def gpu_device: () -> Integer
|
||||
def dtw_token_timestamps=: (boolish) -> boolish
|
||||
def dtw_token_timestamps: () -> (true | false)
|
||||
def dtw_aheads_preset=: (Integer) -> Integer
|
||||
def dtw_aheads_preset: () -> Integer
|
||||
def dtw_n_top=: (Integer | nil) -> (Integer | nil)
|
||||
def dtw_n_top: () -> (Integer | nil)
|
||||
end
|
||||
end
|
||||
|
||||
class Params
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestContextParams < TestBase
|
||||
PARAM_NAMES = [
|
||||
:use_gpu,
|
||||
:flash_attn,
|
||||
:gpu_device,
|
||||
:dtw_token_timestamps,
|
||||
:dtw_aheads_preset,
|
||||
:dtw_n_top
|
||||
]
|
||||
|
||||
def test_new
|
||||
params = Whisper::Context::Params.new
|
||||
assert_instance_of Whisper::Context::Params, params
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
params = Whisper::Context::Params.new
|
||||
|
||||
assert_true params.use_gpu
|
||||
params.use_gpu = false
|
||||
assert_false params.use_gpu
|
||||
|
||||
assert_true params.flash_attn
|
||||
params.flash_attn = false
|
||||
assert_false params.flash_attn
|
||||
|
||||
assert_equal 0, params.gpu_device
|
||||
params.gpu_device = 1
|
||||
assert_equal 1, params.gpu_device
|
||||
|
||||
assert_false params.dtw_token_timestamps
|
||||
params.dtw_token_timestamps = true
|
||||
assert_true params.dtw_token_timestamps
|
||||
|
||||
assert_equal Whisper::AHEADS_NONE, params.dtw_aheads_preset
|
||||
params.dtw_aheads_preset =Whisper::AHEADS_BASE
|
||||
assert_equal Whisper::AHEADS_BASE, params.dtw_aheads_preset
|
||||
|
||||
assert_nil params.dtw_n_top
|
||||
params.dtw_n_top = 6
|
||||
assert_equal 6, params.dtw_n_top
|
||||
params.dtw_n_top = nil
|
||||
assert_nil params.dtw_n_top
|
||||
end
|
||||
|
||||
def test_new_with_kw_args
|
||||
params = Whisper::Context::Params.new(use_gpu: false)
|
||||
assert_false params.use_gpu
|
||||
end
|
||||
|
||||
def test_new_with_kw_wargs_non_existent
|
||||
assert_raise ArgumentError do
|
||||
Whisper::Context::Params.new(non_existent: "value")
|
||||
end
|
||||
end
|
||||
|
||||
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
|
||||
def test_new_with_kw_args_default_values(param)
|
||||
default_params = Whisper::Context::Params.new
|
||||
default_value = default_params.send(param)
|
||||
value = if param == :dtw_n_top
|
||||
6
|
||||
else
|
||||
case default_value
|
||||
in true | false
|
||||
!default_value
|
||||
in Integer
|
||||
default_value + 1
|
||||
end
|
||||
end
|
||||
params = Whisper::Context::Params.new(param => value)
|
||||
assert_equal value, params.send(param)
|
||||
|
||||
PARAM_NAMES.reject {|name| name == param}.each do |name|
|
||||
expected = default_params.send(name)
|
||||
actual = params.send(name)
|
||||
assert_equal expected, actual
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -56,6 +56,17 @@ class TestToken < TestBase
|
|||
@segment.each_token.collect(&:text)
|
||||
end
|
||||
|
||||
def test_token_timestamps
|
||||
params = Whisper::Params.new(token_timestamps: true)
|
||||
whisper.transcribe(TestBase::AUDIO, params)
|
||||
prev = -1
|
||||
whisper.each_segment.first.each_token do |token|
|
||||
assert token.start_time >= prev
|
||||
assert token.end_time >= token.start_time
|
||||
prev = token.end_time
|
||||
end
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_nil
|
||||
keys = %i[id tid probability log_probability pt ptsum t_dtw voice_length start_time end_time text]
|
||||
expected = keys.collect {|key| [key, @token.send(key)] }.to_h
|
||||
|
|
|
|||
Loading…
Reference in New Issue