/*
 * whisper.cpp/bindings/ruby/ext/ruby_whisper_parakeet_token.c
 * (189 lines, 5.5 KiB, C)
 */

#include "ruby_whisper.h"
/*
 * X-macro tables describing the attributes of a Token. Every entry has the form
 *   ITERATOR(rb_name, s_key, c_name, p_name, type)
 * where
 *   rb_name - Ruby method name (stringified when the reader is registered),
 *   s_key   - name of the Symbol used as a key by #deconstruct_keys,
 *   c_name  - suffix of the C getter function and of the index enum constant,
 *   p_name  - field that is read: a member of parakeet_token_data for
 *             ITERATE_MEMBERS, a member of ruby_whisper_parakeet_token for
 *             ITERATE_ATTRS,
 *   type    - selects the VAL_FROM_<type> converter (ITERATE_ATTRS entries
 *             are returned as-is, so their type is not used for conversion).
 * Entry order matters: the enum below and the default key list built in
 * #deconstruct_keys are both generated from these tables.
 */
/* Attributes read from the token's own copy of parakeet_token_data. */
#define ITERATE_MEMBERS(ITERATOR) \
ITERATOR(id, id, id, id, INT) \
ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \
ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \
ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \
ITERATOR(probability, probability, p, p, FLOAT) \
ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \
ITERATOR(start_time, start_time, start_time, t0, TIME) \
ITERATOR(end_time, end_time, end_time, t1, TIME) \
ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL)
/* Attributes stored directly on the wrapper struct as Ruby VALUEs. */
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(text, text, text, text, STRING)
/*
 * One index constant RUBY_WHISPER_PARAKEET_TOKEN_<c_name> per table entry.
 * RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS is the total number of entries; it
 * sizes the default key list in #deconstruct_keys.
 */
enum {
#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name,
ITERATE_MEMBERS(DEF_IDX)
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS,
};
/*
 * Converters from C field values to Ruby objects. READER(type) picks one of
 * them using the `type` column of the attribute tables.
 */
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_FROM_FLOAT(v) (DBL2NUM(v))
/*
 * Timestamps are scaled by 10 (presumably 10 ms ticks -> milliseconds;
 * TODO confirm the unit of parakeet_token_data's t0/t1). The argument is
 * parenthesized so that an expression argument is scaled as a whole. The
 * multiplication by 10LL is done in at least 64-bit arithmetic for integer
 * arguments. LL2NUM is used instead of LONG2NUM so the result is not
 * truncated on platforms where long is 32 bits.
 */
#define VAL_FROM_TIME(v) (LL2NUM((v) * 10LL))
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
#define VAL_FROM_STRING(v) (rb_str_new2(v))
#define READER(type) VAL_FROM_##type
#define MEMBER_NAME(name) name
/*
 * Defines ruby_whisper_parakeet_token_get_<c_name>(self). The generated
 * function returns field p_name of the token's copied parakeet_token_data,
 * converted with VAL_FROM_<type>.
 */
#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \
static VALUE \
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
{ \
ruby_whisper_parakeet_token *rwpt; \
GetParakeetToken(self, rwpt); \
return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \
}
/*
 * Defines ruby_whisper_parakeet_token_get_<c_name>(self). The generated
 * function returns the VALUE stored in the wrapper's p_name field unchanged.
 */
#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \
static VALUE \
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
{ \
ruby_whisper_parakeet_token *rwpt; \
GetParakeetToken(self, rwpt); \
return rwpt->p_name; \
}
/* Ruby class object for Token. Assigned in init_ruby_whisper_parakeet_token, which defines it under the module passed there. */
VALUE cParakeetToken;
/*
 * One cached Symbol per table entry (sym_id, ..., sym_word_start, sym_text).
 * They are assigned in init_ruby_whisper_parakeet_token and compared by
 * identity against the requested keys in #deconstruct_keys.
 */
#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key;
ITERATE_MEMBERS(DEC_ATTR_SYMS)
ITERATE_ATTRS(DEC_ATTR_SYMS)
/*
 * GC mark callback. The only Ruby object the wrapper references is the
 * cached #text String, so that is the only thing marked.
 */
static void
ruby_whisper_parakeet_token_mark(void *p)
{
  const ruby_whisper_parakeet_token *token = p;
  rb_gc_mark(token->text);
}
/*
 * GC free callback. It releases the heap copy of parakeet_token_data and then
 * the wrapper itself. xfree() accepts NULL, so a token whose data was never
 * attached needs no special case.
 */
static void
ruby_whisper_parakeet_token_free(void *p)
{
  ruby_whisper_parakeet_token *token = p;
  xfree(token->token_data);
  xfree(token);
}
/*
 * Memory-size callback (used by ObjectSpace.memsize_of). It reports the
 * wrapper plus the attached parakeet_token_data copy, if there is one.
 */
static size_t
ruby_whisper_parakeet_token_memsize(const void *p)
{
  const ruby_whisper_parakeet_token *token = p;
  if (token == NULL) {
    return 0;
  }
  const size_t data_size = token->token_data != NULL ? sizeof(*token->token_data) : 0;
  return sizeof(*token) + data_size;
}
static const rb_data_type_t ruby_whisper_parakeet_token_type = {
"ruby_whisper_parakeet_token",
{ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize},
0, 0,
0,
};
/*
 * Allocator for Token. The result is an empty wrapper with no token data
 * attached and nil as its cached text.
 */
static VALUE
ruby_whisper_parakeet_token_s_allocate(VALUE klass)
{
  ruby_whisper_parakeet_token *data;
  const VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, data);
  data->text = Qnil;
  data->token_data = NULL;
  return obj;
}
/*
 * Builds a Token from a parakeet_token_data.
 *
 * The data is copied into a heap block owned by the new object. Callers may
 * therefore pass the address of a temporary (ruby_whisper_parakeet_token_s_from_index
 * does). The token's text is resolved once with parakeet_token_to_str and
 * cached as a UTF-8 String.
 * NOTE(review): assumes parakeet_token_to_str never returns NULL for a valid id.
 */
VALUE
ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
{
  ruby_whisper_parakeet_token *wrapped;
  const VALUE obj = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
  TypedData_Get_Struct(obj, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, wrapped);

  parakeet_token_data *copy = ALLOC(parakeet_token_data);
  *copy = *token_data;
  wrapped->token_data = copy;

  const char *piece = parakeet_token_to_str(context, token_data->id);
  wrapped->text = rb_utf8_str_new_cstr(piece);
  return obj;
}
/*
 * Builds a Token for token i_token of segment i_segment of the last full
 * transcription held by context. The data is fetched by value and then
 * copied into the new object.
 */
VALUE
ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
{
  const parakeet_token_data fetched = parakeet_full_get_token_data(context, i_segment, i_token);
  return ruby_whisper_parakeet_token_s_from_token_data(context, &fetched);
}
/* One reader per ITERATE_MEMBERS entry; each converts a field of the copied parakeet_token_data. */
ITERATE_MEMBERS(DEF_MEMBER_ATTR)
/*
 * Defines #text. It returns the String cached in rwpt->text, which
 * ruby_whisper_parakeet_token_s_from_token_data builds with
 * parakeet_token_to_str. NOTE(review): an earlier note mentioned
 * parakeet_token_to_text as an alternative source; decide whether that is wanted.
 */
ITERATE_ATTRS(DEF_ATTR)
/*
 * call-seq:
 *   deconstruct_keys(keys) -> hash
 *
 * Supports hash patterns (e.g. `case token in {text:, start_time:}`).
 * Returns a Hash that maps each requested attribute Symbol to the value of
 * the matching reader. When +keys+ is nil, every attribute is included.
 * Keys that are not attribute Symbols are skipped. As with
 * Struct#deconstruct_keys, an empty Hash is returned when more keys are
 * requested than the token has attributes.
 *
 * Raises TypeError when +keys+ is neither nil nor an Array.
 */
static VALUE
ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys)
{
  ruby_whisper_parakeet_token *rwpt;
  /* rwpt is not used below. The macro (from ruby_whisper.h) is presumably
   * kept for the validation it performs on self; confirm. */
  GetParakeetToken(self, rwpt);
  VALUE hash = rb_hash_new();
  long n_keys = 0;
  if (NIL_P(keys)) {
    VALUE attrs[] = {
#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key,
      ITERATE_MEMBERS(LIST_SYMS)
      ITERATE_ATTRS(LIST_SYMS)
    };
    keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs);
    n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS;
  } else {
    /* RARRAY_LEN is only valid on an Array. Reject other types instead of
     * reading arbitrary memory; the message matches Struct#deconstruct_keys. */
    if (!RB_TYPE_P(keys, T_ARRAY)) {
      rb_raise(rb_eTypeError,
               "wrong argument type %"PRIsVALUE" (expected Array or nil)",
               rb_obj_class(keys));
    }
    n_keys = RARRAY_LEN(keys);
    if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) {
      return hash;
    }
  }
  for (long i = 0; i < n_keys; i++) {
    VALUE key = rb_ary_entry(keys, i);
    /* Symbols are compared by identity against the cached sym_* values. */
#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \
    if (key == sym_##s_key) { \
      rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \
    }
    ITERATE_MEMBERS(CHECK_AND_SET_KEY)
    ITERATE_ATTRS(CHECK_AND_SET_KEY)
  }
  return hash;
}
/*
 * Defines the Token class under *mParakeet, with its allocator, one reader
 * per table entry and #deconstruct_keys. It also fills in the cached key
 * Symbols used by #deconstruct_keys.
 */
void
init_ruby_whisper_parakeet_token(VALUE *mParakeet)
{
  const VALUE outer = *mParakeet;
  cParakeetToken = rb_define_class_under(outer, "Token", rb_cObject);
  rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate);

  /* For each entry, cache its key Symbol and then register its reader method. */
#define REGISTER_TOKEN_READER(rb_name, s_key, c_name, p_name, type) \
  { \
    sym_##s_key = ID2SYM(rb_intern(#s_key)); \
    rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0); \
  }
  ITERATE_MEMBERS(REGISTER_TOKEN_READER)
  ITERATE_ATTRS(REGISTER_TOKEN_READER)

  rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1);
}