Add Parakeet::Context#each_segment

This commit is contained in:
Kitaiti Makoto 2026-05-15 01:38:08 +09:00
parent 6bcc09b911
commit bb096e54ac
6 changed files with 103 additions and 3 deletions

View File

@ -9,7 +9,9 @@ VALUE cVADContext;
VALUE cVADParams;
VALUE cVADSegments;
VALUE cVADSegment;
/* Class handles for the Parakeet bindings; each is assigned by the matching init_ruby_whisper_parakeet_* function. */
VALUE cParakeetContext;
VALUE cParakeetParams;
VALUE cParakeetSegment;
VALUE eError;
VALUE cSegment;

View File

@ -98,6 +98,11 @@ typedef struct {
struct parakeet_context *context;
} ruby_whisper_parakeet_context;
/* Data wrapped by a Parakeet segment object: the context it came from plus the segment's index. */
typedef struct {
/* Ruby context object the segment was created from; marked during GC by the segment's mark function. */
VALUE context;
/* Index of the segment, as passed to ruby_whisper_parakeet_segment_init. */
int index;
} ruby_whisper_parakeet_segment;
#define GetContext(obj, rw) do { \
TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \
if ((rw)->context == NULL) { \

View File

@ -1,9 +1,14 @@
#include "ruby_whisper.h"
extern ID id_to_s;
extern ID id___method__;
extern ID id_to_enum;
extern VALUE cParakeetContext;
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
static void
ruby_whisper_parakeet_context_free(void *p)
@ -65,13 +70,33 @@ ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self)
return Qnil;
}
/*
 * Yields one segment object per segment reported by the context, in index
 * order, and returns self. When called without a block, returns the result
 * of to_enum for the current method instead.
 *
 * call-seq:
 *   each_segment {|segment| ... } -> self
 *   each_segment -> Enumerator
 */
static VALUE
ruby_whisper_parakeet_context_each_segment(VALUE self)
{
  ruby_whisper_parakeet_context *ctx;
  int total;
  int idx;

  /* No block: hand back an enumerator before touching the native context. */
  if (!rb_block_given_p()) {
    return rb_funcall(self, id_to_enum, 1, rb_funcall(self, id___method__, 0));
  }

  GetParakeetContext(self, ctx);

  /* The segment count is read once, before the first yield. */
  total = parakeet_full_n_segments(ctx->context);
  idx = 0;
  while (idx < total) {
    rb_yield(ruby_whisper_parakeet_segment_init(self, idx));
    ++idx;
  }

  return self;
}
/*
 * Defines the "Context" class under the given module and registers its
 * allocator and instance methods.
 *
 * mParakeet: pointer to the module under which the class is defined.
 */
void
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
{
  /*
   * Assign the file-scope cParakeetContext (declared extern above). It must
   * not be redeclared as a local here: a local would shadow the global and
   * leave it unset for other translation units.
   */
  cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);

  rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate);
  rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1);
  rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2);
  rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0);
}

View File

@ -35,6 +35,8 @@ enum {
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse)
extern VALUE cParakeetParams;
extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
@ -208,7 +210,7 @@ ITERATE_CALLBACK_PARAMS(INIT_CONTAINER)
void
init_ruby_whisper_parakeet_params(VALUE *mParakeet)
{
VALUE cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject);
cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject);
rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate);
rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1);

View File

@ -1,7 +1,59 @@
#include "ruby_whisper.h"
extern VALUE cParakeetSegment;
/* GC mark callback for segment data: marks the context object the segment refers to. */
static void
rb_whisper_parakeet_segment_mark(void *p)
{
  rb_gc_mark(((ruby_whisper_parakeet_segment *)p)->context);
}
/*
 * Memsize callback for segment data.
 *
 * Returns 0 for a NULL pointer, otherwise the size of the wrapped struct.
 * The struct owns no separately allocated memory, and sizeof(*rwps) already
 * includes the index field, so nothing else is added.
 */
static size_t
ruby_whisper_parakeet_segment_memsize(const void *p)
{
  const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p;

  return rwps ? sizeof(*rwps) : 0;
}
static const rb_data_type_t ruby_whisper_parakeet_segment_type = {
"ruby_whisper_segment",
{rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,},
0, 0,
0
};
/* Allocator: creates an object of klass wrapping a new ruby_whisper_parakeet_segment struct. */
static VALUE
ruby_whisper_parakeet_segment_s_allocate(VALUE klass)
{
  ruby_whisper_parakeet_segment *data;
  const VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, data);

  return obj;
}
/*
 * Builds a segment object of class cParakeetSegment that refers to the
 * given context object and segment index.
 *
 * context: Ruby object stored in the segment (kept reachable via the mark callback).
 * index:   segment index stored in the segment.
 * Returns the new segment object.
 */
VALUE
ruby_whisper_parakeet_segment_init(VALUE context, int index)
{
  ruby_whisper_parakeet_segment *data;
  VALUE obj;

  obj = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment);
  TypedData_Get_Struct(obj, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, data);

  data->index = index;
  data->context = context;

  return obj;
}
/*
 * Defines the "Segment" class under the given module and registers its
 * allocator.
 *
 * mParakeet: pointer to the module under which the class is defined.
 */
void
init_ruby_whisper_parakeet_segment(VALUE *mParakeet)
{
  /* Define the class once and keep the handle in the file-scope global used by ruby_whisper_parakeet_segment_init. */
  cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject);
  rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate);
}

View File

@ -0,0 +1,14 @@
require_relative "helper"
# Checks that Parakeet::Context#each_segment yields Parakeet::Segment objects
# after a transcription has been run.
class TestParakeetSegment < TestBase
  def setup
    @parakeet = Parakeet::Context.new(File.join(__dir__, "../../../models/parakeet-tdt-0.6b-v3.bin"))
    @parakeet.transcribe AUDIO, Parakeet::Params.new
  end

  def test_segment
    yielded = 0
    @parakeet.each_segment do |segment|
      assert_instance_of Parakeet::Segment, segment
      yielded += 1
    end

    # Without this the test would pass vacuously when nothing is yielded.
    assert_operator yielded, :>, 0, "each_segment yielded no segments"
  end
end