Compare commits
6 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
f049fff95a | |
|
|
200b119790 | |
|
|
86c40c3bd6 | |
|
|
0d14756929 | |
|
|
9efddafb91 | |
|
|
3805e602d3 |
|
|
@ -27,6 +27,6 @@ jobs:
|
|||
steps:
|
||||
- uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0
|
||||
with:
|
||||
ruby-version: '3.2'
|
||||
ruby-version: '3.3'
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6
|
||||
- run: rake test
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ jobs:
|
|||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev git ccache
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
make
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ on:
|
|||
type: string
|
||||
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
|
|
@ -117,9 +115,11 @@ jobs:
|
|||
run: |
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
${{ matrix.build == 'arm64' && '-DGGML_CPU_ARM_ARCH=armv8-a' || '' }}
|
||||
${{ matrix.build == 'x64' && '-DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_CPU_ARM_ARCH=armv8-a' }}
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Pack artifacts
|
||||
|
|
@ -175,7 +175,7 @@ jobs:
|
|||
-DBUILD_SHARED_LIBS=ON
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
-DGGML_NATIVE=OFF
|
||||
-DGGML_BMI2=OFF
|
||||
${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
|
|
@ -289,6 +289,8 @@ jobs:
|
|||
-DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib"
|
||||
-DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include"
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
-DGGML_NATIVE=OFF
|
||||
${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
|
|
@ -492,7 +494,10 @@ jobs:
|
|||
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
|
||||
-DSDL2_DIR="%SDL2_DIR%" ^
|
||||
-DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^
|
||||
-DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%"
|
||||
-DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" ^
|
||||
-DGGML_BACKEND_DL=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
-DGGML_CPU_ALL_VARIANTS=ON
|
||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||
cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
|
||||
project("whisper.cpp" C CXX)
|
||||
project("whisper.cpp" VERSION 1.8.7)
|
||||
project("whisper.cpp" VERSION 1.9.1)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(SOVERSION 1)
|
||||
|
|
@ -19,6 +19,7 @@ endif()
|
|||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(WHISPER_STANDALONE ON)
|
||||
|
|
@ -180,12 +181,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location
|
|||
get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
|
||||
|
||||
set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h)
|
||||
|
||||
install(TARGETS whisper LIBRARY PUBLIC_HEADER)
|
||||
|
||||
target_compile_definitions(whisper PRIVATE
|
||||
WHISPER_VERSION="${PROJECT_VERSION}"
|
||||
)
|
||||
|
||||
set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h)
|
||||
install(TARGETS parakeet LIBRARY PUBLIC_HEADER)
|
||||
|
||||
target_compile_definitions(parakeet PRIVATE
|
||||
PARAKEET_VERSION="${PROJECT_VERSION}"
|
||||
)
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake
|
||||
|
|
@ -211,6 +220,35 @@ configure_file(cmake/whisper.pc.in
|
|||
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc"
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
|
||||
|
||||
# Install locations recorded in parakeet-config.cmake. Each one defaults to the
# matching CMAKE_INSTALL_* directory and can be overridden through the cache.
set(PARAKEET_INCLUDE_INSTALL_DIR "${CMAKE_INSTALL_INCLUDEDIR}" CACHE PATH "Location of header files")
set(PARAKEET_LIB_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}" CACHE PATH "Location of library files")
set(PARAKEET_BIN_INSTALL_DIR "${CMAKE_INSTALL_BINDIR}" CACHE PATH "Location of binary files")

# Generate the package config file from its template. Paths are quoted so that
# an empty value or a value containing ';' cannot reshape the argument list.
configure_package_config_file(
    "${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in"
    "${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake"
    INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/parakeet"
    PATH_VARS
        PARAKEET_INCLUDE_INSTALL_DIR
        PARAKEET_LIB_INSTALL_DIR
        PARAKEET_BIN_INSTALL_DIR)

# NOTE(review): the version is taken from WHISPER_INSTALL_VERSION, which is not
# defined in the visible part of this file - confirm it is set before this point.
write_basic_package_version_file(
    "${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake"
    VERSION ${WHISPER_INSTALL_VERSION}
    COMPATIBILITY SameMajorVersion)

install(FILES
        "${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake"
        "${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake"
    DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/parakeet")

# pkg-config metadata; @ONLY keeps any literal ${...} in the template untouched.
configure_file(
    "${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet.pc.in"
    "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc"
    @ONLY)

install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc"
    DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig")
|
||||
|
||||
#
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
[](https://conan.io/center/whisper-cpp)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.8.7](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.7) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
|
||||
Stable: [v1.9.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.8.7",
|
||||
"version": "1.9.1",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
|
|
|||
|
|
@ -396,6 +396,37 @@ whisper
|
|||
.full(Whisper::Params.new, samples)
|
||||
```
|
||||
|
||||
### Parakeet ###
|
||||
|
||||
The whispercpp gem now supports Parakeet, NVIDIA's ASR model.
|
||||
|
||||
If you want to use Parakeet instead of Whisper, the API should feel familiar.
|
||||
In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way.
|
||||
|
||||
```ruby
|
||||
require "whisper"
|
||||
|
||||
# It's handy to assign Whisper::Parakeet to a top-level Parakeet constant, unless you also use the Parakeet gem.
|
||||
Parakeet = Whisper::Parakeet
|
||||
|
||||
parakeet = Parakeet::Context.new("path/to/model")
|
||||
|
||||
params = Parakeet::Params.new(
|
||||
no_context: true
|
||||
)
|
||||
|
||||
parakeet
|
||||
.transcribe("path/to/audio.wav", params)
|
||||
.each_segment do |segment|
|
||||
puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}"
|
||||
end
|
||||
```
|
||||
|
||||
The main differences are:
|
||||
|
||||
* Namespace is `Whisper::Parakeet`.
|
||||
* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks.
|
||||
|
||||
Custom context params
|
||||
---------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,21 @@ else
|
|||
end
|
||||
end
|
||||
|
||||
# Parakeet model fixture used by the test task.
# When the model exists in the repository's models/ directory it is symlinked
# into the fixture directory; otherwise it is downloaded from GitHub.
TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin"
TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin"))
TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d")
directory TEST_PARAKEET_MODEL_DIR
if File.exist? TEST_PARAKEET_MODEL_SRC
  file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t|
    symlink t.source, t.name
  end
else
  require "open-uri"
  file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t|
    # The model is binary data: binwrite avoids newline translation and
    # encoding conversion that text-mode File.write may apply.
    File.binwrite t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read
  end
end
|
||||
|
||||
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
|
||||
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
|
||||
chdir "test/jfk_reader" do
|
||||
|
|
@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
|
|||
end
|
||||
CLEAN.include TEST_MEMORY_VIEW
|
||||
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL]
|
||||
|
|
|
|||
|
|
@ -30,6 +30,6 @@ create_makefile "whisper" do |conf|
|
|||
#{libs}: cmake-targets
|
||||
cmake-targets:
|
||||
#{"\t"}"#{cmake}" -S sources -B build #{options}
|
||||
#{"\t"}"#{cmake}" --build build --config Release --target common whisper
|
||||
#{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet
|
||||
EOF
|
||||
end
|
||||
|
|
|
|||
|
|
@ -1,19 +1,29 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
VALUE mWhisper;
|
||||
VALUE mLogSettable;
|
||||
VALUE mVAD;
|
||||
VALUE mParakeet;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
VALUE cVADContext;
|
||||
VALUE cVADParams;
|
||||
VALUE cVADSegments;
|
||||
VALUE cVADSegment;
|
||||
VALUE cParakeetContext;
|
||||
VALUE cParakeetContextParams;
|
||||
VALUE cParakeetParams;
|
||||
VALUE cParakeetSegment;
|
||||
VALUE cParakeetModel;
|
||||
VALUE eError;
|
||||
|
||||
VALUE cSegment;
|
||||
VALUE cToken;
|
||||
VALUE cModel;
|
||||
|
||||
VALUE mOutputContext;
|
||||
VALUE mOutputSegment;
|
||||
|
||||
ID id_to_s;
|
||||
ID id_call;
|
||||
ID id___method__;
|
||||
|
|
@ -27,9 +37,11 @@ ID id_pre_converted_models;
|
|||
ID id_coreml_compiled_models;
|
||||
ID id_cache;
|
||||
ID id_n_processors;
|
||||
|
||||
static bool is_log_callback_finalized = false;
|
||||
static bool is_ruby_log_callback_present = false;
|
||||
ID id_extended;
|
||||
ID id_start_log_callback_thread;
|
||||
ID id_log_callback_thread;
|
||||
ID id_alive_p;
|
||||
ID id_join;
|
||||
|
||||
// High level API
|
||||
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
|
||||
|
|
@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD);
|
|||
extern void init_ruby_whisper_vad_context(VALUE *mVAD);
|
||||
extern void init_ruby_whisper_vad_segment(VALUE *mVAD);
|
||||
extern void init_ruby_whisper_vad_segments(VALUE *mVAD);
|
||||
extern void init_ruby_whisper_parakeet(VALUE *mWhisper);
|
||||
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
static ruby_whisper_log_queue whisper_log_queue;
|
||||
|
||||
LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set)
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* lang_max_id -> Integer
|
||||
|
|
@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) {
|
|||
return rb_str_new2(whisper_print_system_info());
|
||||
}
|
||||
|
||||
/*
 * Finalizer for the Ruby log callback object (registered by log_set).
 * It only raises the "finalized" flag; ruby_whisper_log_callback returns
 * early once that flag is set. The id argument is unused.
 */
static VALUE
ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id)
{
  is_log_callback_finalized = true;

  return Qnil;
}
|
||||
|
||||
/* Argument pack handed to call_log_callbacks() through a void pointer. */
typedef struct {
  int level;           /* log level, passed to the Ruby callback as an Integer */
  const char * buffer; /* NUL-terminated message text (read with rb_str_new2) */
} call_log_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_log_callbacks(void *v_args) {
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
if (NIL_P(log_callback)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
call_log_callbacks_args *args = (call_log_callbacks_args *)v_args;
|
||||
VALUE user_data = rb_iv_get(mWhisper, "user_data");
|
||||
rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
|
||||
if (is_log_callback_finalized) {
|
||||
return;
|
||||
}
|
||||
if (!is_ruby_log_callback_present) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_log_callbacks_args args = {
|
||||
level,
|
||||
buffer,
|
||||
};
|
||||
if (ruby_thread_has_gvl_p()) {
|
||||
call_log_callbacks((void *)&args);
|
||||
} else {
|
||||
rb_thread_call_with_gvl(call_log_callbacks, (void *)&args);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
|
||||
*/
|
||||
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
|
||||
VALUE old_callback = rb_iv_get(self, "log_callback");
|
||||
if (!NIL_P(old_callback)) {
|
||||
rb_undefine_finalizer(old_callback);
|
||||
}
|
||||
|
||||
rb_iv_set(self, "log_callback", log_callback);
|
||||
rb_iv_set(self, "user_data", user_data);
|
||||
|
||||
if (!NIL_P(log_callback)) {
|
||||
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
|
||||
rb_define_finalizer(log_callback, finalize_log_callback);
|
||||
}
|
||||
|
||||
if (NIL_P(log_callback)) {
|
||||
whisper_log_set(NULL, NULL);
|
||||
is_ruby_log_callback_present = false;
|
||||
} else {
|
||||
whisper_log_set(ruby_whisper_log_callback, NULL);
|
||||
is_ruby_log_callback_present = true;
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
void Init_whisper() {
|
||||
id_to_s = rb_intern("to_s");
|
||||
id_call = rb_intern("call");
|
||||
|
|
@ -189,9 +133,19 @@ void Init_whisper() {
|
|||
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
|
||||
id_cache = rb_intern("cache");
|
||||
id_n_processors = rb_intern("n_processors");
|
||||
id_extended = rb_intern("extended");
|
||||
id_start_log_callback_thread = rb_intern("start_log_callback_thread");
|
||||
id_log_callback_thread = rb_intern("@log_callback_thread");
|
||||
id_alive_p = rb_intern("alive?");
|
||||
id_join = rb_intern("join");
|
||||
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
rb_require("whisper/log_settable");
|
||||
mLogSettable = rb_path2class("Whisper::LogSettable");
|
||||
mVAD = rb_define_module_under(mWhisper, "VAD");
|
||||
rb_require("whisper/output");
|
||||
mOutputContext = rb_path2class("Whisper::Output::Context");
|
||||
mOutputSegment = rb_path2class("Whisper::Output::Segment");
|
||||
|
||||
rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
|
||||
|
|
@ -222,8 +176,8 @@ void Init_whisper() {
|
|||
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
|
||||
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
|
||||
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
|
||||
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
|
||||
|
||||
LOG_SETTABLE_INIT(whisper_log_queue, mWhisper)
|
||||
|
||||
cContext = init_ruby_whisper_context(&mWhisper);
|
||||
init_ruby_whisper_context_params(&cContext);
|
||||
|
|
@ -236,8 +190,10 @@ void Init_whisper() {
|
|||
init_ruby_whisper_vad_segment(&mVAD);
|
||||
init_ruby_whisper_vad_segments(&mVAD);
|
||||
init_ruby_whisper_vad_context(&mVAD);
|
||||
init_ruby_whisper_parakeet(&mWhisper);
|
||||
|
||||
rb_require("whisper/context");
|
||||
rb_require("whisper/segment");
|
||||
rb_require("whisper/model/uri");
|
||||
|
||||
rb_include_module(cContext, mOutputContext);
|
||||
rb_include_module(cSegment, mOutputSegment);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,8 +5,12 @@
|
|||
#include <ruby/version.h>
|
||||
#include <ruby/util.h>
|
||||
#include <ruby/thread.h>
|
||||
#include <ruby/thread_native.h>
|
||||
#include <ruby/atomic.h>
|
||||
#include <ruby/memory_view.h>
|
||||
#include "whisper.h"
|
||||
#include "parakeet.h"
|
||||
#include "ruby_whisper_log_settable.h"
|
||||
|
||||
#if RUBY_API_VERSION_MAJOR < 4
|
||||
// Exists but not declared as public API
|
||||
|
|
@ -20,13 +24,28 @@ typedef struct {
|
|||
VALUE callbacks;
|
||||
} ruby_whisper_callback_container;
|
||||
|
||||
typedef struct {
|
||||
VALUE *context;
|
||||
VALUE user_data;
|
||||
VALUE callback;
|
||||
VALUE callbacks;
|
||||
bool is_interrupted;
|
||||
} ruby_whisper_abort_callback_container;
|
||||
typedef struct ruby_whisper_abort_callback_user_data {
|
||||
volatile rb_atomic_t is_interrupted;
|
||||
ruby_whisper_callback_container *callback_container;
|
||||
} ruby_whisper_abort_callback_user_data;
|
||||
|
||||
typedef struct ruby_whisper_log {
|
||||
enum ggml_log_level level;
|
||||
char *text;
|
||||
size_t length;
|
||||
size_t capacity;
|
||||
} ruby_whisper_log;
|
||||
|
||||
typedef struct ruby_whisper_log_queue {
|
||||
rb_nativethread_lock_t lock;
|
||||
rb_nativethread_cond_t cond;
|
||||
bool is_open;
|
||||
|
||||
size_t head;
|
||||
size_t tail;
|
||||
size_t size;
|
||||
ruby_whisper_log *logs;
|
||||
} ruby_whisper_log_queue;
|
||||
|
||||
typedef struct {
|
||||
struct whisper_context *context;
|
||||
|
|
@ -42,7 +61,7 @@ typedef struct {
|
|||
ruby_whisper_callback_container *new_segment_callback_container;
|
||||
ruby_whisper_callback_container *progress_callback_container;
|
||||
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
VALUE vad_params;
|
||||
} ruby_whisper_params;
|
||||
|
||||
|
|
@ -84,6 +103,63 @@ typedef struct parsed_samples_t {
|
|||
bool memview_exported;
|
||||
} parsed_samples_t;
|
||||
|
||||
typedef struct {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
} ruby_whisper_full_args;
|
||||
|
||||
typedef struct ruby_whisper_full_parallel_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
} ruby_whisper_full_parallel_args;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_full_params params;
|
||||
ruby_whisper_callback_container *new_segment_callback_container;
|
||||
ruby_whisper_callback_container *new_token_callback_container;
|
||||
ruby_whisper_callback_container *progress_callback_container;
|
||||
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
} ruby_whisper_parakeet_params;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context_params params;
|
||||
} ruby_whisper_parakeet_context_params;
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context *context;
|
||||
} ruby_whisper_parakeet_context;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
int index;
|
||||
} ruby_whisper_parakeet_segment;
|
||||
|
||||
typedef struct {
|
||||
parakeet_token_data *token_data;
|
||||
VALUE text;
|
||||
} ruby_whisper_parakeet_token;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
} ruby_whisper_parakeet_model;
|
||||
|
||||
extern ID id_extended;
|
||||
extern ID id_log_callback_thread;
|
||||
extern ID id_start_log_callback_thread;
|
||||
extern ID id_alive_p;
|
||||
extern ID id_join;
|
||||
extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue);
|
||||
extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue);
|
||||
extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue);
|
||||
extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text);
|
||||
extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue);
|
||||
|
||||
#define GetContext(obj, rw) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \
|
||||
if ((rw)->context == NULL) { \
|
||||
|
|
@ -120,4 +196,47 @@ typedef struct parsed_samples_t {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetContextParams(obj, rwpcp) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetContext(obj, rwpc) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \
|
||||
if ((rwpc)->context == NULL) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetParams(obj, rwpp) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \
|
||||
if (!(rwpp)->new_segment_callback_container || \
|
||||
!(rwpp)->new_token_callback_container || \
|
||||
!(rwpp)->progress_callback_container || \
|
||||
!(rwpp)->encoder_begin_callback_container || \
|
||||
!(rwpp)->abort_callback_container) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetSegment(obj, rwps) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \
|
||||
if (!(rwps)->context) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetToken(obj, rwpt) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \
|
||||
if (!(rwpt)->token_data) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define GetParakeetModel(obj, rwpm) do { \
|
||||
TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \
|
||||
if (NIL_P((rwpm)->context)) { \
|
||||
rb_raise(rb_eRuntimeError, "Not initialized"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
|
|||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_s_new(VALUE context);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
|
||||
|
||||
ID transcribe_option_names[1];
|
||||
|
||||
|
|
@ -38,21 +38,6 @@ typedef struct fill_samples_args {
|
|||
int n_samples;
|
||||
} fill_samples_args;
|
||||
|
||||
typedef struct full_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
} full_args;
|
||||
|
||||
typedef struct full_parallel_args {
|
||||
VALUE *context;
|
||||
VALUE *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
} full_parallel_args;
|
||||
|
||||
typedef struct full_without_gvl_args {
|
||||
struct whisper_context *context;
|
||||
struct whisper_full_params *params;
|
||||
|
|
@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args {
|
|||
} full_parallel_without_gvl_args;
|
||||
|
||||
typedef struct full_ubf_args {
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
|
||||
} full_ubf_args;
|
||||
|
||||
static void
|
||||
|
|
@ -379,7 +364,7 @@ fill_samples(VALUE rb_args)
|
|||
return Qnil;
|
||||
}
|
||||
|
||||
struct parsed_samples_t
|
||||
parsed_samples_t
|
||||
parse_samples(VALUE *samples, VALUE *n_samples)
|
||||
{
|
||||
bool memview_available = rb_memory_view_available_p(*samples);
|
||||
|
|
@ -480,20 +465,24 @@ full_ubf(void *rb_args)
|
|||
{
|
||||
full_ubf_args *args = (full_ubf_args *)rb_args;
|
||||
|
||||
args->abort_callback_container->is_interrupted = true;
|
||||
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
VALUE
|
||||
full_body(VALUE rb_args)
|
||||
{
|
||||
full_args *args = (full_args *)rb_args;
|
||||
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context, 1);
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
prepare_transcription(rwp, args->context, 1, &abort_callback_user_data);
|
||||
|
||||
struct full_without_gvl_args full_without_gvl_args = {
|
||||
rw->context,
|
||||
|
|
@ -503,7 +492,7 @@ full_body(VALUE rb_args)
|
|||
0,
|
||||
};
|
||||
full_ubf_args full_ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args);
|
||||
return INT2NUM(full_without_gvl_args.result);
|
||||
|
|
@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
|
|||
VALUE n_samples = argc == 2 ? Qnil : argv[2];
|
||||
|
||||
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
|
||||
full_args args = {
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
&argv[0],
|
||||
parsed.samples,
|
||||
|
|
@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
VALUE
|
||||
full_parallel_body(VALUE rb_args)
|
||||
{
|
||||
full_parallel_args *args = (full_parallel_args *)rb_args;
|
||||
ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args;
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
GetContext(*args->context, rw);
|
||||
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
|
||||
prepare_transcription(rwp, args->context, args->n_processors);
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data);
|
||||
|
||||
struct full_parallel_without_gvl_args full_parallel_without_gvl_args = {
|
||||
rw->context,
|
||||
|
|
@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args)
|
|||
0,
|
||||
};
|
||||
full_ubf_args full_ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args);
|
||||
return INT2NUM(full_parallel_without_gvl_args.result);
|
||||
|
|
@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
|
|||
break;
|
||||
}
|
||||
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
|
||||
const full_parallel_args args = {
|
||||
const ruby_whisper_full_parallel_args args = {
|
||||
&self,
|
||||
&argv[0],
|
||||
parsed.samples,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define LOG_QUEUE_CAPACITY 256
|
||||
#define LOG_DEFAULT_CAPACITY 1024
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
rb_nativethread_lock_initialize(&log_queue->lock);
|
||||
rb_native_cond_initialize(&log_queue->cond);
|
||||
log_queue->head = 0;
|
||||
log_queue->tail = 0;
|
||||
log_queue->size = 0;
|
||||
log_queue->is_open = true;
|
||||
log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY);
|
||||
for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) {
|
||||
// we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL
|
||||
// this doesn't be freed because log queue lives until the end of process
|
||||
char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY);
|
||||
if (!slot) {
|
||||
rb_raise(rb_eRuntimeError, "Could not allocate memory for log text");
|
||||
}
|
||||
ruby_whisper_log log = {
|
||||
0,
|
||||
slot,
|
||||
0,
|
||||
LOG_QUEUE_CAPACITY,
|
||||
};
|
||||
log_queue->logs[i] = log;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
log_queue->is_open = true;
|
||||
|
||||
rb_native_cond_signal(&log_queue->cond);
|
||||
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
}
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
log_queue->is_open = false;
|
||||
rb_native_cond_broadcast(&log_queue->cond);
|
||||
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
}
|
||||
|
||||
static size_t
|
||||
calc_enough_cap(size_t len)
|
||||
{
|
||||
size_t quot = len / LOG_DEFAULT_CAPACITY;
|
||||
size_t rem = len % LOG_DEFAULT_CAPACITY;
|
||||
|
||||
return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY;
|
||||
}
|
||||
|
||||
void
|
||||
ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text)
|
||||
{
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
if (!log_queue->is_open) {
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t len = strlen(text);
|
||||
ruby_whisper_log *log = &log_queue->logs[log_queue->head];
|
||||
if (len > log->capacity) {
|
||||
size_t new_cap = calc_enough_cap(len);
|
||||
// we cannot call Ruby API like REALLOC_N because this function is called without GVL
|
||||
char *slot = realloc(log->text, new_cap);
|
||||
if (!slot) {
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
return;
|
||||
}
|
||||
log->text = slot;
|
||||
log->capacity = new_cap;
|
||||
}
|
||||
// we cannot call Ruby API like MEMCPY because this function is called without GVL
|
||||
memcpy(log->text, text, sizeof(char) * len);
|
||||
log->length = len;
|
||||
log->level = level;
|
||||
log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY;
|
||||
bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY;
|
||||
log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1;
|
||||
if (is_full) {
|
||||
log_queue->tail = log_queue->head;
|
||||
}
|
||||
|
||||
rb_native_cond_signal(&log_queue->cond);
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
}
|
||||
|
||||
static void*
|
||||
ruby_whisper_log_queue_wait(void *args)
|
||||
{
|
||||
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
|
||||
|
||||
rb_native_cond_wait(&log_queue->cond, &log_queue->lock);
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_log_queue_wait_ubf(void *args)
|
||||
{
|
||||
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
|
||||
|
||||
rb_native_cond_broadcast(&log_queue->cond);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
enum ggml_log_level level;
|
||||
size_t length;
|
||||
char *text;
|
||||
} log_snapshot;
|
||||
|
||||
VALUE
|
||||
ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue)
|
||||
{
|
||||
log_snapshot logs[LOG_QUEUE_CAPACITY];
|
||||
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
|
||||
while (log_queue->size == 0 && log_queue->is_open) {
|
||||
rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue);
|
||||
rb_nativethread_lock_lock(&log_queue->lock);
|
||||
}
|
||||
|
||||
if (log_queue->size == 0 && !log_queue->is_open) {
|
||||
rb_native_cond_broadcast(&log_queue->cond);
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
size_t size = log_queue->size;
|
||||
ruby_whisper_log *log;
|
||||
size_t i;
|
||||
for (i = 0; i < size; i++) {
|
||||
log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY];
|
||||
logs[i].level = log->level;
|
||||
logs[i].length = log->length;
|
||||
char *text = malloc(log->length);
|
||||
if (!text) {
|
||||
logs[i].text = NULL;
|
||||
continue;
|
||||
}
|
||||
logs[i].text = text;
|
||||
memcpy(logs[i].text, log->text, log->length);
|
||||
}
|
||||
log_queue->size = 0;
|
||||
log_queue->tail = log_queue->head;
|
||||
|
||||
rb_native_cond_signal(&log_queue->cond);
|
||||
|
||||
rb_nativethread_lock_unlock(&log_queue->lock);
|
||||
|
||||
VALUE rb_logs = rb_ary_new2(size);
|
||||
VALUE rb_text;
|
||||
for (i = 0; i < size; i++) {
|
||||
if (!logs[i].text) {
|
||||
continue;
|
||||
}
|
||||
rb_text = rb_str_new(logs[i].text, logs[i].length);
|
||||
free(logs[i].text);
|
||||
rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text));
|
||||
}
|
||||
|
||||
return rb_logs;
|
||||
}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
#ifndef RUBY_WHISPER_LOG_SETTABLE_H
|
||||
#define RUBY_WHISPER_LOG_SETTABLE_H
|
||||
|
||||
/*
 * LOG_SETTABLE_SETUP(log_queue, mod, log_set)
 *
 * Expands, at file scope, to four static functions bound to one log queue:
 *   ruby_whisper_<log_queue>_s_drain_logs  - returns ruby_whisper_log_queue_drain(&log_queue)
 *   ruby_whisper_<log_queue>_log_callback  - enqueues (level, text) into the queue
 *   ruby_whisper_<log_queue>_s_log_set     - stores the callback and user data in ivars of self;
 *                                            a nil callback calls log_set(NULL, NULL), otherwise the
 *                                            queue is opened, the callback thread is started on mod
 *                                            and the native callback is installed through log_set
 *   ruby_whisper_<log_queue>_end_proc      - closes the queue and joins the callback thread if alive
 *
 * log_queue must be the identifier of a ruby_whisper_log_queue variable (it is
 * token-pasted into the function names); mod is an expression yielding the
 * module VALUE and is parenthesized at every use.
 */
#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \
  static VALUE \
  ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \
  { \
    return ruby_whisper_log_queue_drain(&log_queue); \
  } \
  static void \
  ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \
  { \
    ruby_whisper_log_queue_enqueue(&log_queue, level, text); \
  } \
  static VALUE \
  ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \
  { \
    rb_iv_set(self, "@log_callback", log_callback); \
    rb_iv_set(self, "@log_callback_user_data", user_data); \
    if (NIL_P(log_callback)) { \
      log_set(NULL, NULL); \
    } else { \
      ruby_whisper_log_queue_open(&log_queue); \
      rb_funcall((mod), id_start_log_callback_thread, 0); \
      log_set(ruby_whisper_##log_queue##_log_callback, NULL); \
    } \
    return Qnil; \
  } \
  static void \
  ruby_whisper_##log_queue##_end_proc(VALUE args) \
  { \
    ruby_whisper_log_queue_close(&log_queue); \
    VALUE log_callback_thread = rb_ivar_get((mod), id_log_callback_thread); \
    if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \
      rb_funcall(log_callback_thread, id_join, 0); \
    } \
  }
|
||||
|
||||
/*
 * LOG_SETTABLE_INIT(log_queue, mod)
 *
 * Statement sequence for an init function: initializes the queue, defines
 * mod.drain_logs and mod.log_set, registers the end proc generated by
 * LOG_SETTABLE_SETUP, extends mod with mLogSettable and calls
 * mLogSettable.extended(mod). The expansion already ends with ';'.
 * mod is parenthesized at every use so any VALUE expression may be passed.
 */
#define LOG_SETTABLE_INIT(log_queue, mod) \
  ruby_whisper_log_queue_initialize(&log_queue); \
  rb_define_singleton_method((mod), "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \
  rb_define_singleton_method((mod), "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \
  rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \
  rb_extend_object((mod), mLogSettable); \
  rb_funcall(mLogSettable, id_extended, 1, (mod));
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
#include "ruby_whisper.h"
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
extern VALUE mParakeet;
|
||||
extern VALUE mLogSettable;
|
||||
extern VALUE cParakeetContext;
|
||||
extern VALUE cParakeetSegment;
|
||||
extern VALUE mOutputContext;
|
||||
extern VALUE mOutputSegment;
|
||||
|
||||
extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet);
|
||||
extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet);
|
||||
extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext);
|
||||
extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet);
|
||||
|
||||
static ruby_whisper_log_queue parakeet_log_queue;
|
||||
|
||||
LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_s_system_info_str(VALUE self)
|
||||
{
|
||||
return rb_str_new2(parakeet_print_system_info());
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet(VALUE *mWhisper)
|
||||
{
|
||||
mParakeet = rb_define_module_under(*mWhisper, "Parakeet");
|
||||
|
||||
rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version()));
|
||||
|
||||
LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet)
|
||||
|
||||
rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0);
|
||||
|
||||
init_ruby_whisper_parakeet_params(&mParakeet);
|
||||
init_ruby_whisper_parakeet_token(&mParakeet);
|
||||
init_ruby_whisper_parakeet_segment(&mParakeet);
|
||||
cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet);
|
||||
init_ruby_whisper_parakeet_context_params(&cParakeetContext);
|
||||
init_ruby_whisper_parakeet_model(&mParakeet);
|
||||
|
||||
rb_include_module(cParakeetContext, mOutputContext);
|
||||
rb_include_module(cParakeetSegment, mOutputSegment);
|
||||
}
|
||||
|
|
@ -0,0 +1,304 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
/*
 * X-macro tables. Each row is ITERATOR(suffix, TYPE): `suffix` is pasted after
 * parakeet_full_ / ruby_whisper_parakeet_context_full_, and TYPE selects the
 * VAL_FROM_<TYPE> converter below.
 */
#define ITERATE_SEGMENT_ATTRS(ITERATOR) \
  ITERATOR(get_segment_t0, LONG) \
  ITERATOR(get_segment_t1, LONG) \
  ITERATOR(get_segment_text, STRING) \
  ITERATOR(n_tokens, INT)

#define ITERATE_TOKEN_ATTRS(ITERATOR) \
  ITERATOR(get_token_text, STRING) \
  ITERATOR(get_token_id, INT) \
  ITERATOR(get_token_p, FLOAT)

/*
 * C -> Ruby converters.
 * LL2NUM rather than LONG2NUM: `long` is 32 bits on LLP64 targets (64-bit
 * Windows), where LONG2NUM would truncate a 64-bit value.
 * NOTE(review): assumes the t0/t1 getters return a 64-bit integer - confirm;
 * LL2NUM yields the same Integer for any value that fits in `long`.
 */
#define VAL_FROM_LONG(v) LL2NUM(v)
#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v)
#define VAL_FROM_INT(v) INT2NUM(v)
#define VAL_FROM_FLOAT(v) DBL2NUM(v)
#define READER(type) VAL_FROM_##type
|
||||
|
||||
extern ID id_to_s;
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
extern ID id_new;
|
||||
|
||||
extern VALUE cParakeetContext;
|
||||
extern VALUE eError;
|
||||
|
||||
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
|
||||
extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params);
|
||||
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
|
||||
extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples);
|
||||
extern VALUE release_samples(VALUE rb_parsed_args);
|
||||
extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
|
||||
extern rb_data_type_t ruby_whisper_parakeet_params_type;
|
||||
/* The object is defined as `const rb_data_type_t`; the declaration must carry the same qualifier. */
extern const rb_data_type_t ruby_whisper_parakeet_context_params_type;
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
|
||||
extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context);
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_context_free(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
|
||||
if (rwpc->context) {
|
||||
parakeet_free(rwpc->context);
|
||||
rwpc->context = NULL;
|
||||
}
|
||||
xfree(rwpc);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_context_memsize(const void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
|
||||
if (!rwpc) {
|
||||
return 0;
|
||||
}
|
||||
size_t size = sizeof(*rwpc);
|
||||
return size;
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_context_type = {
|
||||
"ruby_whisper_parakeet_context",
|
||||
{0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
|
||||
rwpc->context = NULL;
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context **context;
|
||||
char *model_path;
|
||||
struct parakeet_context_params params;
|
||||
} ruby_whisper_parakeet_context_init_args;
|
||||
|
||||
static void*
|
||||
ruby_whisper_parakeet_context_init_without_gvl(void *args)
|
||||
{
|
||||
ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args;
|
||||
*init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
VALUE model_path;
|
||||
VALUE context_params;
|
||||
struct parakeet_context_params params;
|
||||
|
||||
rb_scan_args(argc, argv, "11", &model_path, &context_params);
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
|
||||
|
||||
model_path = ruby_whisper_normalize_model_path(model_path);
|
||||
if (!rb_respond_to(model_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context");
|
||||
}
|
||||
if (NIL_P(context_params)) {
|
||||
params = parakeet_context_default_params();
|
||||
} else {
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
GetParakeetContextParams(context_params, rwpcp);
|
||||
params = rwpcp->params;
|
||||
}
|
||||
ruby_whisper_parakeet_context_init_args init_args = {
|
||||
&rwpc->context,
|
||||
StringValueCStr(model_path),
|
||||
params,
|
||||
};
|
||||
rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL);
|
||||
if (rwpc->context == NULL) {
|
||||
rb_raise(rb_eRuntimeError, "Failed to load model");
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_full_n_segments(VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(self, rwpc);
|
||||
|
||||
return INT2NUM(parakeet_full_n_segments(rwpc->context));
|
||||
}
|
||||
|
||||
#define DEF_SEGMENT_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetContext(self, rwpc); \
|
||||
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \
|
||||
}
|
||||
|
||||
ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR)
|
||||
|
||||
#define DEF_TOKEN_ATTR(name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetContext(self, rwpc); \
|
||||
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \
|
||||
}
|
||||
|
||||
ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token)
|
||||
{
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(self, rwpc);
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token));
|
||||
|
||||
return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_each_segment(VALUE self)
|
||||
{
|
||||
if (!rb_block_given_p()) {
|
||||
const VALUE method_name = rb_funcall(self, id___method__, 0);
|
||||
return rb_funcall(self, id_to_enum, 1, method_name);
|
||||
}
|
||||
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(self, rwpc);
|
||||
|
||||
const int n_segments = parakeet_full_n_segments(rwpc->context);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
rb_yield(ruby_whisper_parakeet_segment_init(self, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
struct parakeet_context *context;
|
||||
struct parakeet_full_params *params;
|
||||
float *samples;
|
||||
int n_samples;
|
||||
int result;
|
||||
} parakeet_full_without_gvl_args;
|
||||
|
||||
static void*
|
||||
parakeet_full_without_gvl(void *rb_args)
|
||||
{
|
||||
parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args;
|
||||
args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples);
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
|
||||
} parakeet_full_ubf_args;
|
||||
|
||||
static void
|
||||
parakeet_full_ubf(void *rb_args)
|
||||
{
|
||||
parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args;
|
||||
|
||||
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_context_full_body(VALUE rb_args)
|
||||
{
|
||||
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(*args->context, rwpc);
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
GetParakeetParams(*args->params, rwpp);
|
||||
|
||||
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
|
||||
0,
|
||||
NULL,
|
||||
};
|
||||
ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data);
|
||||
|
||||
parakeet_full_without_gvl_args full_without_gvl_args = {
|
||||
rwpc->context,
|
||||
&rwpp->params,
|
||||
args->samples,
|
||||
args->n_samples,
|
||||
0
|
||||
};
|
||||
parakeet_full_ubf_args full_ubf_args = {
|
||||
&abort_callback_user_data,
|
||||
};
|
||||
rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args);
|
||||
|
||||
return INT2NUM(full_without_gvl_args.result);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
if (argc < 2 || argc > 3) {
|
||||
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
|
||||
}
|
||||
|
||||
VALUE n_samples = argc == 2 ? Qnil : argv[2];
|
||||
|
||||
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
&argv[0],
|
||||
parsed.samples,
|
||||
parsed.n_samples,
|
||||
};
|
||||
VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed);
|
||||
const int result = NUM2INT(rb_result);
|
||||
if (result == 0) {
|
||||
return self;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
|
||||
}
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_get_model(VALUE self)
|
||||
{
|
||||
return ruby_whisper_parakeet_model_s_new(self);
|
||||
}
|
||||
|
||||
VALUE
|
||||
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate);
|
||||
|
||||
rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1);
|
||||
rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2);
|
||||
rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0);
|
||||
rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2);
|
||||
rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0);
|
||||
rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0);
|
||||
rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1);
|
||||
|
||||
#define REGISTER_SEGMENT_ATTR(name, type) \
|
||||
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1);
|
||||
|
||||
ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR)
|
||||
|
||||
#define REGISTER_TOKEN_ATTR(name, type) \
|
||||
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2);
|
||||
|
||||
ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR)
|
||||
|
||||
return cParakeetContext;
|
||||
}
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
/* X-macro table of context parameters: ITERATOR(field of rwpcp->params, TYPE). */
#define ITERATE_ATTRS(ITERATOR) \
  ITERATOR(use_gpu, BOOL) \
  ITERATOR(gpu_device, INT)

/* Ruby <-> C converters, selected by the TYPE column through READER()/WRITER(). */
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_TO_INT(v) (NUM2INT(v))
#define READER(type) VAL_FROM_##type
#define WRITER(type) VAL_TO_##type

/*
 * Generates the getter/setter pair for one table row:
 *   ..._get_<name>(self)      reads params.<name> and converts it to a VALUE
 *   ..._set_<name>(self, val) converts val, stores it and returns val
 */
#define DEF_ATTR(name, type) \
  static VALUE \
  ruby_whisper_parakeet_context_params_get_##name(VALUE self) \
  { \
    ruby_whisper_parakeet_context_params *rwpcp; \
    GetParakeetContextParams(self, rwpcp); \
    return READER(type)(rwpcp->params.name); \
  } \
  static VALUE \
  ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \
  { \
    ruby_whisper_parakeet_context_params *rwpcp; \
    GetParakeetContextParams(self, rwpcp); \
    rwpcp->params.name = WRITER(type)(val); \
    return val; \
  }
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name,
|
||||
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS
|
||||
};
|
||||
|
||||
extern VALUE cParakeetContextParams;
|
||||
|
||||
typedef VALUE (*param_writer_t)(VALUE, VALUE);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
|
||||
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_context_params_memsize(const void *p)
|
||||
{
|
||||
if (!p) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_context_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_context_params_type = {
|
||||
"ruby_whisper_parakeet_context_params",
|
||||
{0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,},
|
||||
0, 0,
|
||||
0,
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
VALUE kw_hash;
|
||||
VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef};
|
||||
VALUE value;
|
||||
ruby_whisper_parakeet_context_params *rwpcp;
|
||||
int i;
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
|
||||
rwpcp->params = parakeet_context_default_params();
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values);
|
||||
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) {
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
param_writers[i](self, value);
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext)
|
||||
{
|
||||
cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate);
|
||||
|
||||
rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1);
|
||||
|
||||
int i = 0;
|
||||
#define REGISTER_ATTR(name, type) \
|
||||
param_names[i] = rb_intern(#name); \
|
||||
param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \
|
||||
rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \
|
||||
rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \
|
||||
i++;
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
/* X-macro list of model attributes; each name is pasted after parakeet_model_ and exposed as a reader. */
#define ITERATE_ATTRS(ITERATOR) \
  ITERATOR(n_vocab) \
  ITERATOR(n_audio_ctx) \
  ITERATOR(n_audio_state) \
  ITERATOR(n_audio_head) \
  ITERATOR(n_audio_layer) \
  ITERATOR(n_mels) \
  ITERATOR(ftype)
|
||||
|
||||
/* The object is defined as `const rb_data_type_t`; the declaration must carry the same qualifier. */
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
|
||||
extern VALUE cParakeetModel;
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_model_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p;
|
||||
if (!NIL_P(rwpm->context)) {
|
||||
rb_gc_mark(rwpm->context);
|
||||
}
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_model_memsize(const void *p)
|
||||
{
|
||||
if (!p) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_model);
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_parakeet_model_type = {
|
||||
"ruby_whisper_parakeet_model",
|
||||
{ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_model_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_model *rwpm;
|
||||
VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
|
||||
rwpm->context = Qnil;
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_model_s_new(VALUE context)
|
||||
{
|
||||
const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel);
|
||||
ruby_whisper_parakeet_model *rwpm;
|
||||
TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
|
||||
rwpm->context = context;
|
||||
return model;
|
||||
}
|
||||
|
||||
#define DEF_ATTR(name) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_model_get_##name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_model *rwpm; \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetModel(self, rwpm); \
|
||||
GetParakeetContext(rwpm->context, rwpc); \
|
||||
return INT2NUM(parakeet_model_##name(rwpc->context)); \
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_model(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate);
|
||||
|
||||
#define REGISTER_ATTR(name) \
|
||||
rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0);
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
}
|
||||
|
|
@ -0,0 +1,548 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
/* Plain parameters: ITERATOR(name, TYPE). */
#define ITERATE_PARAMS(ITERATOR) \
  ITERATOR(n_threads, INT) \
  ITERATOR(offset_ms, INT) \
  ITERATOR(duration_ms, INT) \
  ITERATOR(no_context, BOOL) \
  ITERATOR(audio_ctx, INT)

/* Base names of the regular callbacks; DATA is forwarded untouched to ITERATOR. */
#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \
  ITERATOR(new_segment, DATA) \
  ITERATOR(new_token, DATA) \
  ITERATOR(progress, DATA) \
  ITERATOR(encoder_begin, DATA)

/* Applies ITERATOR to <name>_callback for every regular callback. */
#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback)
#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
  ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR)

/* All callback parameters: the regular ones followed by abort_callback. */
#define ITERATE_CALLBACK_PARAMS(ITERATOR) \
  ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
  ITERATOR(abort_callback)

/*
 * Parameter indices, in order: plain parameters, callbacks, then the
 * callbacks' *_user_data entries. The trailing enumerator is the total count.
 */
enum {
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name,
#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name,
#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data,
  ITERATE_PARAMS(DEF_IDX)
  ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK)
  ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA)

  RUBY_WHISPER_PARAKEET_NUM_PARAMS
};
|
||||
|
||||
/* Ruby <-> C converters. Every macro argument is parenthesized so that any expression may be passed. */
#define VAL_TO_INT(v) (NUM2INT(v))
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
|
||||
|
||||
extern VALUE cParakeetParams;
|
||||
extern ID id_call;
|
||||
|
||||
extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
|
||||
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
|
||||
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
|
||||
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
|
||||
typedef VALUE (*param_writer_t)(VALUE, VALUE);
|
||||
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
int n_new;
|
||||
} call_parakeet_new_segment_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_new_segment_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
const int n_segments = parakeet_full_n_segments_from_state(args->state);
|
||||
for (int i = args->n_new; i > 0; i--) {
|
||||
int i_segment = n_segments - i;
|
||||
VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment);
|
||||
for (int j = 0; j < n_callbacks; j++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||
rb_funcall(cb, id_call, 1, segment);
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_new_segment_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
n_new,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_context *context;
|
||||
struct parakeet_state *state;
|
||||
const parakeet_token_data *token_data;
|
||||
} call_parakeet_new_token_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_new_token_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args;
|
||||
VALUE token = Qnil;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
if (NIL_P(token)) {
|
||||
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
|
||||
}
|
||||
for (int i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
rb_funcall(cb, id_call, 1, token);
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_new_token_callbacks_args args = {
|
||||
container,
|
||||
context,
|
||||
state,
|
||||
token_data,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct parakeet_state *state;
|
||||
int progress;
|
||||
} call_parakeet_progress_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_progress_callback(void *v_args)
|
||||
{
|
||||
call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data);
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
rb_funcall(cb, id_call, 1, INT2NUM(args->progress));
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return;
|
||||
}
|
||||
|
||||
call_parakeet_progress_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
progress,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args);
|
||||
}
|
||||
|
||||
/* Argument bundle passed (as void *) to call_parakeet_encoder_begin_callbacks. */
typedef struct {
|
||||
/* Container holding the Ruby encoder-begin callables to invoke. */
const ruby_whisper_callback_container *container;
|
||||
/* State pointer received by the native encoder-begin callback. */
struct parakeet_state *state;
|
||||
/* In/out flag: set to false by the callee when a Ruby callable returns false. */
bool is_continued;
|
||||
} call_parakeet_encoder_begin_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_encoder_begin_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
|
||||
if (result == Qfalse) {
|
||||
args->is_continued = false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, i);
|
||||
result = rb_funcall(cb, id_call, 0);
|
||||
if (result == Qfalse) {
|
||||
args->is_continued = false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
|
||||
{
|
||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||
if (!ruby_whisper_callback_container_is_present(container)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
call_parakeet_encoder_begin_callbacks_args args = {
|
||||
container,
|
||||
state,
|
||||
true,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args);
|
||||
|
||||
return args.is_continued;
|
||||
}
|
||||
|
||||
/* Argument bundle passed (as void *) to call_parakeet_abort_callbacks. */
typedef struct {
|
||||
/* Container holding the Ruby abort callables to invoke. */
const ruby_whisper_callback_container *container;
|
||||
/* In/out flag: set to true by the callee when a Ruby callable returns a truthy value. */
bool is_interrupted;
|
||||
} call_parakeet_abort_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_parakeet_abort_callbacks(void *v_args)
|
||||
{
|
||||
call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args;
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (n_callbacks == 0) {
|
||||
return NULL;
|
||||
}
|
||||
VALUE cb;
|
||||
for (long i = 0; i < n_callbacks; i++) {
|
||||
cb = rb_ary_entry(container->callbacks, i);
|
||||
result = rb_funcall(cb, id_call, 0);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_parakeet_abort_callback(void *user_data)
|
||||
{
|
||||
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
|
||||
|
||||
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
|
||||
if (is_interrupted) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
call_parakeet_abort_callbacks_args args = {
|
||||
data->callback_container,
|
||||
false,
|
||||
};
|
||||
rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args);
|
||||
|
||||
return args.is_interrupted;
|
||||
}
|
||||
|
||||
#define CALLBACK_CONTAINER_NAME(name) name ## _container
|
||||
|
||||
void
|
||||
ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
|
||||
{
|
||||
#define PARAM_NAME(name) name
|
||||
#define USER_DATA_NAME(name) name##_user_data
|
||||
#define REGISTER_CALLBACK(name) \
|
||||
if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \
|
||||
rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \
|
||||
rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \
|
||||
rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \
|
||||
}
|
||||
|
||||
ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK)
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) {
|
||||
abort_callback_user_data->callback_container = rwpp->abort_callback_container;
|
||||
}
|
||||
rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback;
|
||||
rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_params_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
|
||||
|
||||
#define MARK_CONTAINER(name) \
|
||||
if (rwpp->name##_container) { \
|
||||
ruby_whisper_callback_container_mark(rwpp->name##_container); \
|
||||
}
|
||||
|
||||
ITERATE_CALLBACK_PARAMS(MARK_CONTAINER)
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_params_free(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
|
||||
|
||||
#define FREE_CONTAINER(name) \
|
||||
if (rwpp->name##_container) { \
|
||||
xfree(rwpp->name##_container); \
|
||||
}
|
||||
|
||||
ITERATE_CALLBACK_PARAMS(FREE_CONTAINER)
|
||||
|
||||
xfree(rwpp);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_params_memsize(const void *p)
|
||||
{
|
||||
const struct ruby_whisper_parakeet_params *params = p;
|
||||
if (!params) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(ruby_whisper_parakeet_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_parakeet_params_type = {
|
||||
"ruby_whisper_parakeet_params",
|
||||
{ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define WRITER(type) VAL_TO_##type
|
||||
/*
 * Generates the Ruby getter and setter for a plain native parameter.
 *
 * name - field of rwpp->params.
 * type - suffix selecting the VAL_FROM_<type> / VAL_TO_<type> converters.
 *
 * The getter converts the native field to a VALUE; the setter converts `val`
 * to the native type, stores it, and returns `val`.
 */
#define DEF_PARAM_ATTR(name, type) \
  static VALUE \
  ruby_whisper_parakeet_params_get_##name(VALUE self) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    return READER(type)(wrapped->params.name); \
  } \
  static VALUE \
  ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    wrapped->params.name = WRITER(type)(val); \
    return val; \
  }
|
||||
|
||||
/*
 * Generates the Ruby getter and setter for the single callable of a callback
 * parameter.  Both operate on the `callback` slot of the <name>_container.
 * The setter stores `val` as-is and returns it.
 */
#define DEF_CALLBACK_PARAM_ATTR(name) \
  static VALUE \
  ruby_whisper_parakeet_params_get_##name(VALUE self) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    return wrapped->CALLBACK_CONTAINER_NAME(name)->callback; \
  } \
  static VALUE \
  ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    wrapped->CALLBACK_CONTAINER_NAME(name)->callback = (val); \
    return val; \
  }
|
||||
|
||||
/*
 * Generates the Ruby getter and setter for the user data attached to a
 * callback parameter.  Both operate on the `user_data` slot of the
 * <name>_container.  The setter stores `val` as-is and returns it.
 */
#define DEF_USER_DATA_PARAM_ATTR(name) \
  static VALUE \
  ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    return wrapped->CALLBACK_CONTAINER_NAME(name)->user_data; \
  } \
  static VALUE \
  ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    wrapped->CALLBACK_CONTAINER_NAME(name)->user_data = val; \
    return val; \
  }
|
||||
|
||||
/*
 * Generates `on_<name>`: appends the block given to the method to the
 * `callbacks` array of the <name>_callback_container, creating the array on
 * first use.  Returns nil.  The `data` macro argument is not used.
 */
#define DEF_HOOK(name, data) \
  static VALUE \
  ruby_whisper_parakeet_params_on_##name(VALUE self) \
  { \
    ruby_whisper_parakeet_params *wrapped; \
    GetParakeetParams(self, wrapped); \
    const VALUE handler = rb_block_proc(); \
    if (NIL_P(wrapped->name##_callback_container->callbacks)) { \
      wrapped->name##_callback_container->callbacks = rb_ary_new(); \
    } \
    rb_ary_push(wrapped->name##_callback_container->callbacks, handler); \
    return Qnil; \
  }
|
||||
|
||||
ITERATE_PARAMS(DEF_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR)
|
||||
ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_params_abort_on(VALUE self)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
GetParakeetParams(self, rwpp);
|
||||
const VALUE blk = rb_block_proc();
|
||||
if (NIL_P(rwpp->abort_callback_container->callbacks)) {
|
||||
rwpp->abort_callback_container->callbacks = rb_ary_new();
|
||||
}
|
||||
rb_ary_push(rwpp->abort_callback_container->callbacks, blk);
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
|
||||
rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
return obj;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
VALUE kw_hash;
|
||||
VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef};
|
||||
VALUE value;
|
||||
ruby_whisper_parakeet_params *rwpp;
|
||||
int i;
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
|
||||
|
||||
#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate();
|
||||
|
||||
ITERATE_CALLBACK_PARAMS(INIT_CONTAINER)
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values);
|
||||
|
||||
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) {
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
param_writers[i](self, value);
|
||||
}
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_params(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject);
|
||||
rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate);
|
||||
|
||||
rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1);
|
||||
|
||||
int i = 0;
|
||||
#define REGISTER_PARAM(name) \
|
||||
param_names[i] = rb_intern(#name); \
|
||||
param_writers[i] = ruby_whisper_parakeet_params_set_##name; \
|
||||
rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \
|
||||
rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \
|
||||
i++;
|
||||
|
||||
#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name)
|
||||
#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name)
|
||||
#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data)
|
||||
|
||||
ITERATE_PARAMS(REGISTER_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR)
|
||||
ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR)
|
||||
|
||||
#define REGISTER_HOOK(name, data) \
|
||||
rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0);
|
||||
|
||||
ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _)
|
||||
|
||||
rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0);
|
||||
}
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(start_time, t0, TIME) \
|
||||
ITERATOR(end_time, t1, TIME) \
|
||||
ITERATOR(text, text, STRING)
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name,
|
||||
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
|
||||
};
|
||||
|
||||
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
|
||||
/* Converts a NUL-terminated C string to a Ruby String tagged as UTF-8, so
 * segment text carries the same encoding as the token text built with
 * rb_utf8_str_new_cstr rather than ASCII-8BIT. */
#define VAL_FROM_STRING(v) (rb_utf8_str_new_cstr(v))
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define DEF_ATTR(rb_name, c_name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_get_##rb_name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_segment *rwps; \
|
||||
GetParakeetSegment(self, rwps); \
|
||||
ruby_whisper_parakeet_context *rwpc; \
|
||||
GetParakeetContext(rwps->context, rwpc); \
|
||||
return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \
|
||||
}
|
||||
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
extern VALUE cParakeetSegment;
|
||||
extern VALUE sym_start_time;
|
||||
extern VALUE sym_end_time;
|
||||
extern VALUE sym_text;
|
||||
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
|
||||
extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token);
|
||||
|
||||
static void
|
||||
rb_whisper_parakeet_segment_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p;
|
||||
rb_gc_mark(rwps->context);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_segment_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p;
|
||||
if (!rwps) {
|
||||
return 0;
|
||||
}
|
||||
return sizeof(*rwps);
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_parakeet_segment_type = {
|
||||
"ruby_whisper_parakeet_segment",
|
||||
{rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_segment_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_segment_init(VALUE context, int index)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
|
||||
const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment);
|
||||
TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
|
||||
rwps->context = context;
|
||||
rwps->index = index;
|
||||
|
||||
return segment;
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_segment_each_token(VALUE self)
|
||||
{
|
||||
if (!rb_block_given_p()) {
|
||||
const VALUE method_name = rb_funcall(self, id___method__, 0);
|
||||
return rb_funcall(self, id_to_enum, 1, method_name);
|
||||
}
|
||||
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
GetParakeetSegment(self, rwps);
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(rwps->context, rwpc);
|
||||
|
||||
const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index);
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys)
|
||||
{
|
||||
ruby_whisper_parakeet_segment *rwps;
|
||||
GetParakeetSegment(self, rwps);
|
||||
ruby_whisper_parakeet_context *rwpc;
|
||||
GetParakeetContext(rwps->context, rwpc);
|
||||
|
||||
VALUE hash = rb_hash_new();
|
||||
long n_keys;
|
||||
if (NIL_P(keys)) {
|
||||
keys = rb_ary_new3(
|
||||
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
|
||||
sym_start_time,
|
||||
sym_end_time,
|
||||
sym_text
|
||||
);
|
||||
n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS;
|
||||
} else {
|
||||
n_keys = RARRAY_LEN(keys);
|
||||
if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) {
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < n_keys; i++) {
|
||||
VALUE key = rb_ary_entry(keys, i);
|
||||
|
||||
#define CHECK_AND_SET_KEY(rb_name, c_name, type) \
|
||||
if (key == sym_##rb_name) { \
|
||||
rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \
|
||||
}
|
||||
|
||||
ITERATE_ATTRS(CHECK_AND_SET_KEY)
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_segment(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate);
|
||||
|
||||
#define REGISTER_ATTR(rb_name, c_name, type) \
|
||||
rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0);
|
||||
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
|
||||
rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0);
|
||||
rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1);
|
||||
}
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
#include "ruby_whisper.h"
|
||||
|
||||
#define ITERATE_MEMBERS(ITERATOR) \
|
||||
ITERATOR(id, id, id, id, INT) \
|
||||
ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \
|
||||
ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \
|
||||
ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \
|
||||
ITERATOR(probability, probability, p, p, FLOAT) \
|
||||
ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \
|
||||
ITERATOR(start_time, start_time, start_time, t0, TIME) \
|
||||
ITERATOR(end_time, end_time, end_time, t1, TIME) \
|
||||
ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL)
|
||||
|
||||
#define ITERATE_ATTRS(ITERATOR) \
|
||||
ITERATOR(text, text, text, text, STRING)
|
||||
|
||||
enum {
|
||||
#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name,
|
||||
|
||||
ITERATE_MEMBERS(DEF_IDX)
|
||||
ITERATE_ATTRS(DEF_IDX)
|
||||
RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS,
|
||||
};
|
||||
|
||||
#define VAL_FROM_INT(v) (INT2NUM(v))
|
||||
#define VAL_FROM_FLOAT(v) (DBL2NUM(v))
|
||||
/* Multiplies the native time value by 10 and converts it to a Ruby Integer.
 * The argument is parenthesised so an expression such as `a + b` is scaled
 * as a whole. */
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
|
||||
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
|
||||
#define VAL_FROM_STRING(v) (rb_str_new2(v))
|
||||
|
||||
#define READER(type) VAL_FROM_##type
|
||||
#define MEMBER_NAME(name) name
|
||||
#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_token *rwpt; \
|
||||
GetParakeetToken(self, rwpt); \
|
||||
return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \
|
||||
}
|
||||
|
||||
#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \
|
||||
static VALUE \
|
||||
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
|
||||
{ \
|
||||
ruby_whisper_parakeet_token *rwpt; \
|
||||
GetParakeetToken(self, rwpt); \
|
||||
return rwpt->p_name; \
|
||||
}
|
||||
|
||||
VALUE cParakeetToken;
|
||||
|
||||
#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key;
|
||||
|
||||
ITERATE_MEMBERS(DEC_ATTR_SYMS)
|
||||
ITERATE_ATTRS(DEC_ATTR_SYMS)
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_token_mark(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
|
||||
rb_gc_mark(rwpt->text);
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_parakeet_token_free(void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
|
||||
if (rwpt->token_data) {
|
||||
xfree(rwpt->token_data);
|
||||
rwpt->token_data = NULL;
|
||||
}
|
||||
xfree(rwpt);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_parakeet_token_memsize(const void *p)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
|
||||
if (!rwpt) {
|
||||
return 0;
|
||||
}
|
||||
size_t size = sizeof(*rwpt);
|
||||
if (rwpt->token_data) {
|
||||
size += sizeof(*rwpt->token_data);
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_parakeet_token_type = {
|
||||
"ruby_whisper_parakeet_token",
|
||||
{ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize},
|
||||
0, 0,
|
||||
0,
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_token_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt;
|
||||
VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
|
||||
|
||||
rwpt->token_data = NULL;
|
||||
rwpt->text = Qnil;
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
|
||||
{
|
||||
const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
|
||||
ruby_whisper_parakeet_token *rwpt;
|
||||
TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
|
||||
|
||||
rwpt->token_data = ALLOC(parakeet_token_data);
|
||||
*rwpt->token_data = *token_data;
|
||||
rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id));
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
|
||||
{
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token);
|
||||
return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data);
|
||||
}
|
||||
|
||||
ITERATE_MEMBERS(DEF_MEMBER_ATTR)
|
||||
// Define #text using parakeet_token_to_str or parakeet_token_to_text
|
||||
ITERATE_ATTRS(DEF_ATTR)
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys)
|
||||
{
|
||||
ruby_whisper_parakeet_token *rwpt;
|
||||
GetParakeetToken(self, rwpt);
|
||||
|
||||
VALUE hash = rb_hash_new();
|
||||
long n_keys = 0;
|
||||
|
||||
if (NIL_P(keys)) {
|
||||
VALUE attrs[] = {
|
||||
#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key,
|
||||
|
||||
ITERATE_MEMBERS(LIST_SYMS)
|
||||
ITERATE_ATTRS(LIST_SYMS)
|
||||
};
|
||||
keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs);
|
||||
n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS;
|
||||
} else {
|
||||
n_keys = RARRAY_LEN(keys);
|
||||
if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) {
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
for (long i = 0; i < n_keys; i++) {
|
||||
VALUE key = rb_ary_entry(keys, i);
|
||||
|
||||
#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \
|
||||
if (key == sym_##s_key) { \
|
||||
rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \
|
||||
}
|
||||
|
||||
ITERATE_MEMBERS(CHECK_AND_SET_KEY)
|
||||
ITERATE_ATTRS(CHECK_AND_SET_KEY)
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_parakeet_token(VALUE *mParakeet)
|
||||
{
|
||||
cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject);
|
||||
rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate);
|
||||
|
||||
#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \
|
||||
sym_##s_key = ID2SYM(rb_intern(#s_key)); \
|
||||
rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0);
|
||||
|
||||
ITERATE_MEMBERS(REGISTER_ATTR)
|
||||
ITERATE_ATTRS(REGISTER_ATTR)
|
||||
|
||||
rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1);
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
#include "ruby_whisper.h"
|
||||
#include "common-whisper.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
|
||||
extern const rb_data_type_t ruby_whisper_parakeet_params_type;
|
||||
|
||||
extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args);
|
||||
|
||||
extern ID id_to_path;
|
||||
extern ID id_new;
|
||||
|
||||
extern VALUE eError;
|
||||
|
||||
/*
 * Transcribes an audio file with this context.
 *
 * self       - the parakeet context object.
 * audio_path - a String, or any object responding to #to_path.
 * params     - a parakeet Params object.
 *
 * Returns self on success.  Raises RuntimeError ("Failed to open <path>")
 * when the audio cannot be read, and the binding's Error class (constructed
 * from the native result code) when transcription returns non-zero.
 *
 * Ruby exceptions unwind with longjmp and do not run C++ destructors, so the
 * function is arranged such that nothing that can raise runs while the
 * std::string / std::vector objects are alive:
 *   - the wrapped objects and the path are validated first;
 *   - the transcription body runs under rb_protect, and any pending
 *     exception is re-raised only after the buffers have been destroyed.
 */
VALUE
ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params)
{
  ruby_whisper_parakeet_context *rwpc;
  ruby_whisper_parakeet_params *rwpp;
  GetParakeetContext(self, rwpc);
  GetParakeetParams(params, rwpp);

  if (rb_respond_to(audio_path, id_to_path)) {
    audio_path = rb_funcall(audio_path, id_to_path, 0);
  }
  /* May raise (non-String, embedded NUL); no C++ object exists yet. */
  const char *path = StringValueCStr(audio_path);

  bool loaded = false;
  int state = 0;
  VALUE rb_result = Qnil;
  {
    std::string fname = path;
    std::vector<float> pcmf32;
    std::vector<std::vector<float>> pcmf32s;

    loaded = read_audio_data(fname, pcmf32, pcmf32s, false);
    if (loaded) {
      ruby_whisper_full_args args = {
        &self,
        &params,
        pcmf32.data(),
        (int)pcmf32.size(),
      };
      /* Capture any Ruby exception instead of jumping over the buffers. */
      rb_result = rb_protect(ruby_whisper_parakeet_context_full_body, (VALUE)&args, &state);
    }
  }

  if (!loaded) {
    /* PRIsVALUE formats the Ruby String directly and keeps it referenced. */
    rb_raise(rb_eRuntimeError, "Failed to open %" PRIsVALUE, audio_path);
  }
  if (state != 0) {
    rb_jump_tag(state);
  }

  const int result = NUM2INT(rb_result);
  if (result != 0) {
    rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
  }

  return self;
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -76,8 +76,8 @@ static ID id_vad;
|
|||
static ID id_vad_model_path;
|
||||
static ID id_vad_params;
|
||||
|
||||
static void
|
||||
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
|
||||
void
|
||||
ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc)
|
||||
{
|
||||
if (rwc == NULL) return;
|
||||
|
||||
|
|
@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
|
|||
rb_gc_mark(rwc->callbacks);
|
||||
}
|
||||
|
||||
static ruby_whisper_callback_container*
|
||||
rb_whisper_callback_container_allocate() {
|
||||
ruby_whisper_callback_container*
|
||||
ruby_whisper_callback_container_allocate() {
|
||||
ruby_whisper_callback_container *container;
|
||||
container = ALLOC(ruby_whisper_callback_container);
|
||||
container->context = NULL;
|
||||
|
|
@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() {
|
|||
return container;
|
||||
}
|
||||
|
||||
static void
|
||||
rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc)
|
||||
{
|
||||
if (rwc == NULL) return;
|
||||
|
||||
rb_gc_mark(rwc->user_data);
|
||||
rb_gc_mark(rwc->callback);
|
||||
rb_gc_mark(rwc->callbacks);
|
||||
}
|
||||
|
||||
static ruby_whisper_abort_callback_container*
|
||||
rb_whisper_abort_callback_container_allocate() {
|
||||
ruby_whisper_abort_callback_container *container;
|
||||
container = ALLOC(ruby_whisper_abort_callback_container);
|
||||
container->context = NULL;
|
||||
container->user_data = Qnil;
|
||||
container->callback = Qnil;
|
||||
container->callbacks = Qnil;
|
||||
container->is_interrupted = false;
|
||||
return container;
|
||||
}
|
||||
|
||||
static bool
|
||||
bool
|
||||
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
|
||||
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
|
||||
}
|
||||
|
||||
static bool
|
||||
ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) {
|
||||
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_callback_container *container;
|
||||
struct whisper_state *state;
|
||||
|
|
@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s
|
|||
}
|
||||
|
||||
typedef struct {
|
||||
const ruby_whisper_abort_callback_container *container;
|
||||
struct whisper_state *state;
|
||||
const ruby_whisper_callback_container *container;
|
||||
bool is_interrupted;
|
||||
} call_abort_callbacks_args;
|
||||
|
||||
static void*
|
||||
call_abort_callbacks(void *v_args) {
|
||||
call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args;
|
||||
const ruby_whisper_abort_callback_container *container = args->container;
|
||||
|
||||
if (container->is_interrupted) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
const ruby_whisper_callback_container *container = args->container;
|
||||
VALUE result = Qnil;
|
||||
|
||||
if (!NIL_P(container->callback)) {
|
||||
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (!NIL_P(result) && Qfalse != result) {
|
||||
result = rb_funcall(container->callback, id_call, 1, container->user_data);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) {
|
|||
if (NIL_P(container->callbacks)) {
|
||||
return NULL;
|
||||
}
|
||||
const long callbacks_len = RARRAY_LEN(container->callbacks);
|
||||
if (0 == callbacks_len) {
|
||||
const long n_callbacks = RARRAY_LEN(container->callbacks);
|
||||
if (0 == n_callbacks) {
|
||||
return NULL;
|
||||
}
|
||||
for (int j = 0; j < callbacks_len; j++) {
|
||||
for (int j = 0; j < n_callbacks; j++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
|
||||
if (!NIL_P(result) && Qfalse != result) {
|
||||
VALUE result = rb_funcall(cb, id_call, 0);
|
||||
if (RTEST(result)) {
|
||||
args->is_interrupted = true;
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) {
|
|||
}
|
||||
|
||||
static bool abort_callback(void * user_data) {
|
||||
const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data;
|
||||
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
|
||||
|
||||
if (container->is_interrupted) {
|
||||
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
|
||||
if (is_interrupted) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!ruby_whisper_abort_callback_container_is_present(container)) {
|
||||
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
call_abort_callbacks_args args = {
|
||||
container,
|
||||
NULL,
|
||||
data->callback_container,
|
||||
false
|
||||
};
|
||||
rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args);
|
||||
|
|
@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
|
|||
return;
|
||||
}
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription");
|
||||
}
|
||||
// new_segment_callback is called only after multiple threads are joined
|
||||
// progress_callback is not called when parallel
|
||||
|
||||
if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) {
|
||||
if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) {
|
||||
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
|
||||
}
|
||||
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
if (!NIL_P(log_callback)) {
|
||||
rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription");
|
||||
}
|
||||
}
|
||||
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) {
|
||||
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
|
||||
rwp->new_segment_callback_container->context = context;
|
||||
rwp->params.new_segment_callback = new_segment_callback;
|
||||
|
|
@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
|||
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
|
||||
}
|
||||
|
||||
abort_callback_user_data->callback_container = rwp->abort_callback_container;
|
||||
rwp->abort_callback_container->context = context;
|
||||
rwp->params.abort_callback = abort_callback;
|
||||
rwp->abort_callback_container->is_interrupted = false;
|
||||
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
|
||||
rwp->params.abort_callback_user_data = (void *)abort_callback_user_data;
|
||||
}
|
||||
|
||||
static void set_vad_params(ruby_whisper_params *rwp)
|
||||
|
|
@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp)
|
|||
rwp->params.vad_params = rwvp->params;
|
||||
}
|
||||
|
||||
/*
|
||||
TODO: Set abort callback to trap SIGINT and SIGTERM
|
||||
*/
|
||||
void
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
|
||||
{
|
||||
check_thread_safety(rwp, n_processors);
|
||||
register_callbacks(rwp, context);
|
||||
register_callbacks(rwp, context, abort_callback_user_data);
|
||||
set_vad_params(rwp);
|
||||
}
|
||||
|
||||
|
|
@ -421,10 +376,10 @@ void
|
|||
rb_whisper_params_mark(void *p)
|
||||
{
|
||||
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
|
||||
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
|
||||
rb_whisper_abort_callback_container_mark(rwp->abort_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->new_segment_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->progress_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container);
|
||||
ruby_whisper_callback_container_mark(rwp->abort_callback_container);
|
||||
rb_gc_mark(rwp->vad_params);
|
||||
}
|
||||
|
||||
|
|
@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass)
|
|||
}
|
||||
rwp->diarize = false;
|
||||
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
|
||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate();
|
||||
rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = ruby_whisper_callback_container_allocate();
|
||||
return obj;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
static VALUE sym_start_time;
|
||||
static VALUE sym_end_time;
|
||||
static VALUE sym_text;
|
||||
static VALUE sym_no_speech_prob;
|
||||
static VALUE sym_speaker_turn_next;
|
||||
static VALUE sym_n_tokens;
|
||||
VALUE sym_start_time;
|
||||
VALUE sym_end_time;
|
||||
VALUE sym_text;
|
||||
VALUE sym_no_speech_prob;
|
||||
VALUE sym_speaker_turn_next;
|
||||
VALUE sym_n_tokens;
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_type;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ extern ID id_to_path;
|
|||
extern ID transcribe_option_names[1];
|
||||
|
||||
extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
|
||||
extern VALUE full_body(VALUE rb_args);
|
||||
extern VALUE full_parallel_body(VALUE rb_args);
|
||||
|
||||
typedef struct{
|
||||
struct whisper_context *context;
|
||||
|
|
@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
ruby_whisper_abort_callback_container *abort_callback_container;
|
||||
} transcribe_ubf_args;
|
||||
|
||||
static void
|
||||
transcribe_ubf(void *rb_args)
|
||||
{
|
||||
transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args;
|
||||
|
||||
args->abort_callback_container->is_interrupted = true;
|
||||
}
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
* can emit to a block results
|
||||
|
|
@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
// Commented out because it is work in progress
|
||||
// {
|
||||
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
// bool is_aborted = *(bool*)user_data;
|
||||
// return !is_aborted;
|
||||
// };
|
||||
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
// }
|
||||
|
||||
prepare_transcription(rwp, &self, n_processors);
|
||||
|
||||
transcribe_without_gvl_args args = {
|
||||
rw->context,
|
||||
&rwp->params,
|
||||
pcmf32.data(),
|
||||
pcmf32.size(),
|
||||
n_processors,
|
||||
0,
|
||||
};
|
||||
transcribe_ubf_args ubf_args = {
|
||||
rwp->abort_callback_container,
|
||||
};
|
||||
rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args);
|
||||
if (args.result != 0) {
|
||||
VALUE rb_result;
|
||||
if (n_processors == 1) {
|
||||
ruby_whisper_full_args args = {
|
||||
&self,
|
||||
&params,
|
||||
pcmf32.data(),
|
||||
(int)pcmf32.size(),
|
||||
};
|
||||
rb_result = full_body((VALUE)&args);
|
||||
} else {
|
||||
ruby_whisper_full_parallel_args parallel_args = {
|
||||
&self,
|
||||
&params,
|
||||
pcmf32.data(),
|
||||
(int)pcmf32.size(),
|
||||
n_processors,
|
||||
};
|
||||
rb_result = full_parallel_body((VALUE)&parallel_args);
|
||||
}
|
||||
const int result = NUM2INT(rb_result);
|
||||
if (result != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
module Whisper
|
||||
class Context
|
||||
def to_srt
|
||||
each_segment.with_index.reduce("") {|srt, (segment, index)|
|
||||
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
|
||||
}
|
||||
end
|
||||
|
||||
def to_webvtt
|
||||
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
|
||||
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
require "mutex_m"
|
||||
|
||||
module Whisper
|
||||
module LogSettable
|
||||
class << self
|
||||
def extended(base)
|
||||
base.extend Mutex_m
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def start_log_callback_thread
|
||||
return if @log_callback_thread&.alive?
|
||||
|
||||
@log_callback_thread = Thread.new {
|
||||
begin
|
||||
while logs = drain_logs
|
||||
begin
|
||||
callback, user_data = synchronize {[@log_callback, @log_callback_user_data]}
|
||||
next if callback.nil?
|
||||
|
||||
logs.each do |(level, text)|
|
||||
callback.call level, text, user_data
|
||||
end
|
||||
rescue => err
|
||||
$stderr.puts err
|
||||
end
|
||||
end
|
||||
rescue => err
|
||||
$stderr.puts err
|
||||
end
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -41,6 +41,8 @@ module Whisper
|
|||
|
||||
def cache
|
||||
path = cache_path
|
||||
return path if cache_path.exist?
|
||||
|
||||
headers = {}
|
||||
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
|
||||
request @uri, headers
|
||||
|
|
@ -216,8 +218,18 @@ module Whisper
|
|||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
%w[
|
||||
parakeet-tdt-0.6b-v3-f16
|
||||
parakeet-tdt-0.6b-v3-f32
|
||||
parakeet-tdt-0.6b-v3-q4_0
|
||||
parakeet-tdt-0.6b-v3-q4_k
|
||||
parakeet-tdt-0.6b-v3-q8_0
|
||||
].each do |name|
|
||||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
@coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models|
|
||||
next if name.end_with?("-tdrz") || name.start_with?("silero-")
|
||||
next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-")
|
||||
|
||||
if matched = name.match(/\A(?<name>.*)-q\d_\d\z/)
|
||||
name = matched[:name]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
module Whisper
|
||||
module Output
|
||||
module Context
|
||||
def to_srt
|
||||
each_segment.with_index.reduce("") {|srt, (segment, index)|
|
||||
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
|
||||
}
|
||||
end
|
||||
|
||||
def to_webvtt
|
||||
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
|
||||
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
module Segment
|
||||
SRT_ESCAPES = {
|
||||
"&" => "&amp;",
|
||||
"<" => "&lt;",
|
||||
">" => "&gt;",
|
||||
}
|
||||
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
|
||||
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
|
||||
|
||||
def to_srt_cue
|
||||
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
|
||||
end
|
||||
|
||||
def to_webvtt_cue
|
||||
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def time_to_a(time)
|
||||
sec, decimal_part = time.divmod(1000)
|
||||
min, sec = sec.divmod(60)
|
||||
hour, min = min.divmod(60)
|
||||
[hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
def srt_time(time)
|
||||
"%02d:%02d:%02d,%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def srt_start_time
|
||||
srt_time(start_time)
|
||||
end
|
||||
|
||||
def srt_end_time
|
||||
srt_time(end_time)
|
||||
end
|
||||
|
||||
def srt_text
|
||||
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
|
||||
end
|
||||
|
||||
def webvtt_time(time)
|
||||
"%02d:%02d:%02d.%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def webvtt_start_time
|
||||
webvtt_time(start_time)
|
||||
end
|
||||
|
||||
def webvtt_end_time
|
||||
webvtt_time(end_time)
|
||||
end
|
||||
|
||||
alias webvtt_text srt_text
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
module Whisper
|
||||
class Segment
|
||||
SRT_ESCAPES = {
|
||||
"&" => "&amp;",
|
||||
"<" => "&lt;",
|
||||
">" => "&gt;",
|
||||
}
|
||||
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
|
||||
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
|
||||
|
||||
def to_srt_cue
|
||||
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
|
||||
end
|
||||
|
||||
def to_webvtt_cue
|
||||
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def time_to_a(time)
|
||||
sec, decimal_part = time.divmod(1000)
|
||||
min, sec = sec.divmod(60)
|
||||
hour, min = min.divmod(60)
|
||||
[hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
def srt_time(time)
|
||||
"%02d:%02d:%02d,%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def srt_start_time
|
||||
srt_time(start_time)
|
||||
end
|
||||
|
||||
def srt_end_time
|
||||
srt_time(end_time)
|
||||
end
|
||||
|
||||
def srt_text
|
||||
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
|
||||
end
|
||||
|
||||
def webvtt_time(time)
|
||||
"%02d:%02d:%02d.%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def webvtt_start_time
|
||||
webvtt_time(start_time)
|
||||
end
|
||||
|
||||
def webvtt_end_time
|
||||
webvtt_time(end_time)
|
||||
end
|
||||
|
||||
alias webvtt_text srt_text
|
||||
end
|
||||
end
|
||||
|
|
@ -40,7 +40,21 @@ module Whisper
|
|||
def self.log_set: (log_callback?, Object? user_data) -> log_callback
|
||||
def self.system_info_str: () -> String
|
||||
|
||||
module Output
|
||||
module Context
|
||||
def to_srt: () -> String
|
||||
def to_webvtt: () -> String
|
||||
end
|
||||
|
||||
module Segment
|
||||
def to_srt_cue: () -> String
|
||||
def to_webvtt_cue: () -> String
|
||||
end
|
||||
end
|
||||
|
||||
class Context
|
||||
include Output::Context
|
||||
|
||||
def self.new: (String | path | ::URI::HTTP) -> instance
|
||||
|
||||
# transcribe a single file
|
||||
|
|
@ -139,17 +153,14 @@ module Whisper
|
|||
| (Whisper::Params, _Samples, ?Integer n_samples) -> self
|
||||
| (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
||||
|
||||
def to_srt: () -> String
|
||||
def to_webvtt: () -> String
|
||||
|
||||
class Params
|
||||
def self.new: (
|
||||
use_gpu: boolish,
|
||||
flash_attn: boolish,
|
||||
gpu_device: Integer,
|
||||
dtw_token_timestamps: boolish,
|
||||
dtw_aheads_preset: Integer,
|
||||
dtw_n_top: Integer | nil,
|
||||
?use_gpu: boolish,
|
||||
?flash_attn: boolish,
|
||||
?gpu_device: Integer,
|
||||
?dtw_token_timestamps: boolish,
|
||||
?dtw_aheads_preset: Integer,
|
||||
?dtw_n_top: Integer | nil,
|
||||
) -> instance
|
||||
|
||||
def use_gpu=: (boolish) -> boolish
|
||||
|
|
@ -444,6 +455,9 @@ module Whisper
|
|||
def abort_on: { (Object user_data) -> boolish } -> void
|
||||
end
|
||||
|
||||
module LogSettable
|
||||
end
|
||||
|
||||
class Model
|
||||
def self.pre_converted_models: () -> Hash[String, Model::URI]
|
||||
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
|
||||
|
|
@ -474,6 +488,8 @@ module Whisper
|
|||
end
|
||||
|
||||
class Segment
|
||||
include Output::Segment
|
||||
|
||||
type deconstructed_keys = {
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
|
|
@ -514,9 +530,6 @@ module Whisper
|
|||
#
|
||||
def each_token: { (Token) -> void } -> void
|
||||
| () -> Enumerator[Token]
|
||||
def to_srt_cue: () -> String
|
||||
def to_webvtt_cue: () -> String
|
||||
|
||||
|
||||
# Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next`
|
||||
#
|
||||
|
|
@ -528,7 +541,7 @@ module Whisper
|
|||
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
module Token
|
||||
class Token
|
||||
type deconstructed_keys = {
|
||||
id: (Integer | nil),
|
||||
tid: (Integer | nil),
|
||||
|
|
@ -598,6 +611,336 @@ module Whisper
|
|||
def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
module Parakeet
|
||||
extend LogSettable
|
||||
|
||||
VERSION: String
|
||||
|
||||
# Control logging output. The default behavior is to print to stderr.
|
||||
#
|
||||
def self.log_set: (nil, Object? user_data) -> nil
|
||||
| (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil
|
||||
def self.system_info_str: () -> String
|
||||
|
||||
class Context
|
||||
include Output::Context
|
||||
|
||||
# Load a Parakeet model from the given file path.
|
||||
#
|
||||
def self.new: (String | path | ::URI::HTTP, ?Params) -> instance
|
||||
|
||||
# Transcribe a single audio file.
|
||||
#
|
||||
def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self
|
||||
|
||||
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||
# Not thread safe for the same context.
|
||||
#
|
||||
# The second argument `samples` must be an array of samples, respond to `:length`,
|
||||
# or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
|
||||
#
|
||||
def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self
|
||||
| (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self
|
||||
|
||||
# Number of generated text segments.
|
||||
#
|
||||
def full_n_segments: () -> Integer
|
||||
|
||||
# Start time of a segment indexed by `segment_index` in centiseconds (units of 10 milliseconds).
|
||||
#
|
||||
# full_get_segment_t0(3) # => 1668 (16680 ms)
|
||||
#
|
||||
def full_get_segment_t0: (Integer segment_index) -> Integer
|
||||
|
||||
# End time of a segment indexed by `segment_index` in centiseconds (units of 10 milliseconds).
|
||||
#
|
||||
# full_get_segment_t1(3) # => 1668 (16680 ms)
|
||||
#
|
||||
def full_get_segment_t1: (Integer segment_index) -> Integer
|
||||
|
||||
# Text of a segment indexed by `segment_index`.
|
||||
#
|
||||
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
|
||||
#
|
||||
def full_get_segment_text: (Integer segment_index) -> String
|
||||
|
||||
# Number of tokens in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_n_tokens: (Integer segment_index) -> Integer
|
||||
|
||||
# Text of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_text: (Integer segment_index, Integer token_index) -> String
|
||||
|
||||
# Token id of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer
|
||||
|
||||
# Probability of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_p: (Integer segment_index, Integer token_index) -> Float
|
||||
|
||||
# Token data of the token indexed by `token_index` in the segment indexed by `segment_index`.
|
||||
#
|
||||
def full_get_token_data: (Integer segment_index, Integer token_index) -> Token
|
||||
|
||||
def model: () -> Model
|
||||
|
||||
# Yields each Whisper::Parakeet::Segment:
|
||||
#
|
||||
# parakeet.transcribe("path/to/audio.wav", params)
|
||||
# parakeet.each_segment do |segment|
|
||||
# puts segment.text
|
||||
# end
|
||||
#
|
||||
# Returns an `Enumerator` if no block given:
|
||||
#
|
||||
# parakeet.transcribe("path/to/audio.wav", params)
|
||||
# enum = parakeet.each_segment
|
||||
# enum.to_a # => [#<Whisper::Parakeet::Segment>, ...]
|
||||
#
|
||||
def each_segment: { (Segment) -> void } -> void
|
||||
| () -> Enumerator[Segment]
|
||||
|
||||
class Params
|
||||
def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance
|
||||
def use_gpu: () -> boolish
|
||||
def use_gpu=: (boolish) -> boolish
|
||||
def gpu_device: () -> Integer
|
||||
def gpu_device=: (Integer) -> Integer
|
||||
end
|
||||
end
|
||||
|
||||
class Params
|
||||
def self.new: (
|
||||
?n_threads: Integer,
|
||||
?offset_ms: Integer,
|
||||
?duration_ms: Integer,
|
||||
?no_context: boolish,
|
||||
?audio_ctx: Integer,
|
||||
?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void,
|
||||
?new_segment_callback_user_data: Object,
|
||||
?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void,
|
||||
?new_token_callback_user_data: Object,
|
||||
?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void,
|
||||
?progress_callback_user_data: Object,
|
||||
?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish,
|
||||
?encoder_begin_callback_user_data: Object,
|
||||
?abort_callback: ^(Object user_data) -> boolish,
|
||||
?abort_callback_user_data: Object
|
||||
) -> instance
|
||||
|
||||
# Number of threads to use.
|
||||
#
|
||||
def n_threads=: (Integer) -> Integer
|
||||
def n_threads: () -> Integer
|
||||
|
||||
# Start offset in ms.
|
||||
#
|
||||
def offset_ms=: (Integer) -> Integer
|
||||
def offset_ms: () -> Integer
|
||||
|
||||
# Audio duration to process in ms.
|
||||
#
|
||||
def duration_ms=: (Integer) -> Integer
|
||||
def duration_ms: () -> Integer
|
||||
|
||||
# If `true`, does not use past transcription (if any) as context.
|
||||
#
|
||||
def no_context=: (boolish) -> boolish
|
||||
def no_context: () -> (true | false)
|
||||
|
||||
# Overwrite the audio context size. `0` uses the default value.
|
||||
#
|
||||
def audio_ctx=: (Integer) -> Integer
|
||||
def audio_ctx: () -> Integer
|
||||
|
||||
# Sets new segment callback, called for every newly generated text segment.
|
||||
#
|
||||
# params.new_segment_callback = ->(context, _, n_new, user_data) {
|
||||
# # ...
|
||||
# }
|
||||
#
|
||||
def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void)
|
||||
def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of new segment callback.
|
||||
#
|
||||
def new_segment_callback_user_data=: (Object?) -> Object?
|
||||
def new_segment_callback_user_data: () -> Object?
|
||||
|
||||
# Sets token callback, called for every newly predicted token.
|
||||
#
|
||||
def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void)
|
||||
def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of token callback.
|
||||
#
|
||||
def new_token_callback_user_data=: (Object?) -> Object?
|
||||
def new_token_callback_user_data: () -> Object?
|
||||
|
||||
# Sets progress callback, called on each progress update.
|
||||
#
|
||||
# +progress+ is an Integer between 0 and 100.
|
||||
#
|
||||
def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void)
|
||||
def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of progress callback.
|
||||
#
|
||||
def progress_callback_user_data=: (Object?) -> Object?
|
||||
def progress_callback_user_data: () -> Object?
|
||||
|
||||
# Sets encoder begin callback, called each time before the encoder starts.
|
||||
#
|
||||
# If it returns `false`, the computation is aborted.
|
||||
#
|
||||
def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish)
|
||||
def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of encoder begin callback.
|
||||
#
|
||||
def encoder_begin_callback_user_data=: (Object?) -> Object?
|
||||
def encoder_begin_callback_user_data: () -> Object?
|
||||
|
||||
# Sets abort callback, called each time before ggml computation starts.
|
||||
#
|
||||
def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish)
|
||||
def abort_callback: () -> ((^(Object user_data) -> boolish) | nil)
|
||||
|
||||
# Sets user data passed to the last argument of abort callback.
|
||||
#
|
||||
def abort_callback_user_data=: (Object?) -> Object?
|
||||
def abort_callback_user_data: () -> Object?
|
||||
|
||||
# Hook called on new segment. Yields each Whisper::Parakeet::Segment.
|
||||
#
|
||||
def on_new_segment: { (Segment) -> void } -> void
|
||||
|
||||
# Hook called on new token. Yields each Whisper::Parakeet::Token.
|
||||
#
|
||||
def on_new_token: { (Token) -> void } -> void
|
||||
|
||||
# Hook called on progress update. Yields each progress `Integer` between 0 and 100.
|
||||
#
|
||||
def on_progress: { (Integer progress) -> void } -> void
|
||||
|
||||
# Hook called each time before the encoder starts.
|
||||
#
|
||||
def on_encoder_begin: { () -> boolish } -> void
|
||||
|
||||
# Call block to determine whether abort or not. Return `true` when you want to abort.
|
||||
#
|
||||
def abort_on: { () -> boolish } -> void
|
||||
end
|
||||
|
||||
class Segment
|
||||
include Output::Segment
|
||||
|
||||
type deconstructed_keys = {
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
text: (String | nil)
|
||||
}
|
||||
|
||||
# Start time in milliseconds.
|
||||
#
|
||||
def start_time: () -> Integer
|
||||
|
||||
# End time in milliseconds.
|
||||
#
|
||||
def end_time: () -> Integer
|
||||
|
||||
# Text of the segment.
|
||||
#
|
||||
def text: () -> String
|
||||
|
||||
# Yields each Whisper::Parakeet::Token:
|
||||
#
|
||||
# parakeet.each_segment.first.each_token do |token|
|
||||
# p token
|
||||
# end
|
||||
#
|
||||
# Returns an `Enumerator` if no block is given:
|
||||
#
|
||||
# parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...]
|
||||
#
|
||||
def each_token: { (Token) -> void } -> void
|
||||
| () -> Enumerator[Token]
|
||||
|
||||
# Possible keys: `:start_time`, `:end_time`, `:text`
|
||||
#
|
||||
def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
class Token
|
||||
type deconstructed_keys = {
|
||||
id: (Integer | nil),
|
||||
duration_idx: (Integer | nil),
|
||||
duration_value: (Integer | nil),
|
||||
frame_index: (Integer | nil),
|
||||
probability: (Float | nil),
|
||||
log_probability: (Float | nil),
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
word_start: ((true | false) | nil),
|
||||
text: (String | nil),
|
||||
}
|
||||
|
||||
# Token ID.
|
||||
#
|
||||
def id: () -> Integer
|
||||
|
||||
# Index into the model's durations array.
|
||||
#
|
||||
def duration_idx: () -> Integer
|
||||
|
||||
# Actual duration value.
|
||||
#
|
||||
def duration_value: () -> Integer
|
||||
|
||||
# Frame index of the token.
|
||||
#
|
||||
def frame_index: () -> Integer
|
||||
|
||||
# Probability of the token.
|
||||
#
|
||||
def probability: () -> Float
|
||||
|
||||
# Log probability of the token.
|
||||
#
|
||||
def log_probability: () -> Float
|
||||
|
||||
# Start time of the token in milliseconds.
|
||||
#
|
||||
def start_time: () -> Integer
|
||||
|
||||
# End time of the token in milliseconds.
|
||||
#
|
||||
def end_time: () -> Integer
|
||||
|
||||
# Whether this token is the start of a word.
|
||||
#
|
||||
def word_start?: () -> (true | false)
|
||||
|
||||
# Get the token text of the token.
|
||||
#
|
||||
def text: () -> String
|
||||
|
||||
def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
class Model
|
||||
def n_vocab: () -> Integer
|
||||
def n_audio_ctx: () -> Integer
|
||||
def n_audio_state: () -> Integer
|
||||
def n_audio_head: () -> Integer
|
||||
def n_audio_layer: () -> Integer
|
||||
def n_mels: () -> Integer
|
||||
def ftype: () -> Integer
|
||||
end
|
||||
end
|
||||
|
||||
module VAD
|
||||
class Params
|
||||
def self.new: (
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ require_relative "jfk_reader/jfk_reader"
|
|||
class TestBase < Test::Unit::TestCase
|
||||
AUDIO = File.join(__dir__, "fixtures", "jfk.wav")
|
||||
|
||||
Parakeet = Whisper::Parakeet
|
||||
|
||||
class << self
|
||||
def whisper
|
||||
return @whisper if @whisper
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ class TestCallback < TestBase
|
|||
return false
|
||||
}
|
||||
@whisper.transcribe(@audio, @params)
|
||||
sleep 0.5 # wait for logs dequeued
|
||||
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
|
||||
Whisper.log_set ->(level, buffer, user_data) {}, nil
|
||||
end
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
require_relative "helper"
|
||||
require "stringio"
|
||||
|
||||
class TestParakeet < TestBase
|
||||
def test_log_set
|
||||
log_callback = Parakeet.instance_variable_get("@log_callback")
|
||||
user_data = Parakeet.instance_variable_get("@log_callback_user_data")
|
||||
|
||||
$stdout = StringIO.new
|
||||
Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil
|
||||
Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
|
||||
sleep 0.1
|
||||
$stdout.rewind
|
||||
logs = $stdout.string
|
||||
assert_match /loading model from/, logs
|
||||
ensure
|
||||
$stdout = STDOUT
|
||||
Parakeet.log_set log_callback, user_data
|
||||
end
|
||||
|
||||
def test_system_info_str
|
||||
assert_match /\APARAKEET : /, Parakeet.system_info_str
|
||||
end
|
||||
|
||||
def test_version
|
||||
assert_instance_of String, Parakeet::VERSION
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetCallback < TestBase
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
Whisper.instance_variable_set "@whisper", nil
|
||||
GC.start
|
||||
@params = Parakeet::Params.new
|
||||
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
end
|
||||
|
||||
def test_new_segment_callback
|
||||
@params.new_segment_callback = ->(context, state, n_new, user_data) {
|
||||
assert_kind_of Integer, n_new
|
||||
assert n_new > 0
|
||||
assert_same @parakeet, context
|
||||
|
||||
n_segments = context.full_n_segments
|
||||
n_new.times do |i|
|
||||
i_segment = n_segments - 1 + i
|
||||
start_time = context.full_get_segment_t0(i_segment) * 10
|
||||
end_time = context.full_get_segment_t1(i_segment) * 10
|
||||
text = context.full_get_segment_text(i_segment)
|
||||
|
||||
assert_kind_of Integer, start_time
|
||||
assert start_time >= 0
|
||||
assert_kind_of Integer, end_time
|
||||
assert end_time > 0
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0
|
||||
end
|
||||
}
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
end
|
||||
|
||||
def test_on_new_segment
|
||||
seg = nil
|
||||
index = 0
|
||||
@params.on_new_segment do |segment|
|
||||
assert_instance_of Parakeet::Segment, segment
|
||||
if index == 0
|
||||
seg = segment
|
||||
assert_equal 0, segment.start_time
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text)
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
assert_equal 0, seg.start_time
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text
|
||||
end
|
||||
|
||||
def test_on_new_token
|
||||
index = 0
|
||||
@params.on_new_token do |token|
|
||||
assert_instance_of Parakeet::Token, token
|
||||
if index == 0
|
||||
assert_instance_of Integer, token.start_time
|
||||
assert_match "▁And", token.text
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
end
|
||||
|
||||
def test_on_progress
|
||||
first = nil
|
||||
@params.on_progress do |progress|
|
||||
assert_kind_of Integer, progress
|
||||
assert 0 <= progress && progress <= 100
|
||||
first = progress if first.nil?
|
||||
end
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
|
||||
assert_equal 0, first
|
||||
end
|
||||
|
||||
def test_on_encoder_begin
|
||||
i = 0
|
||||
@params.on_encoder_begin do
|
||||
i += 1
|
||||
end
|
||||
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
|
||||
assert i > 0
|
||||
end
|
||||
|
||||
def test_abort_on
|
||||
do_abort = false
|
||||
@params.on_new_segment do |segment|
|
||||
do_abort = true if segment.text.match?(/ask/)
|
||||
end
|
||||
i = 0
|
||||
@params.abort_on do
|
||||
i += 1
|
||||
do_abort
|
||||
end
|
||||
|
||||
@parakeet.transcribe(AUDIO, @params) rescue nil
|
||||
|
||||
assert i > 0
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
require_relative "helper"
|
||||
require "stringio"
|
||||
|
||||
class TestParakeetContext < TestBase
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
Whisper.instance_variable_set "@whisper", nil
|
||||
GC.start
|
||||
|
||||
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
@params = Parakeet::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Context, @parakeet
|
||||
end
|
||||
|
||||
def test_new_with_params
|
||||
log_callback = Parakeet.instance_variable_get(:@log_callback)
|
||||
user_data = Parakeet.instance_variable_get(:@log_callback_user_data)
|
||||
begin
|
||||
logs = ""
|
||||
Parakeet.log_set proc {|level, message| logs << message}, nil
|
||||
params = Parakeet::Context::Params.new(use_gpu: false)
|
||||
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params)
|
||||
assert_instance_of Parakeet::Context, parakeet
|
||||
assert_match /use gpu\s+=\s+0/, logs
|
||||
ensure
|
||||
Parakeet.log_set log_callback, user_data
|
||||
end
|
||||
end
|
||||
|
||||
sub_test_case "full" do
|
||||
def setup
|
||||
super
|
||||
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
|
||||
end
|
||||
|
||||
def test_full
|
||||
@parakeet.full @params, @samples, @samples.length
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text
|
||||
end
|
||||
|
||||
def test_full_without_length
|
||||
@parakeet.full(@params, @samples)
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
|
||||
end
|
||||
|
||||
def test_full_enumerator
|
||||
samples = @samples.each
|
||||
@parakeet.full @params, samples, @samples.length
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
|
||||
end
|
||||
|
||||
def test_full_enumerator_without_length
|
||||
samples = @samples.each
|
||||
assert_raise ArgumentError do
|
||||
@parakeet.full @params, samples
|
||||
end
|
||||
end
|
||||
|
||||
def test_full_enumerator_with_too_large_length
|
||||
samples = @samples.each.take(10).to_enum
|
||||
assert_raise StopIteration do
|
||||
@parakeet.full @params, samples, 11
|
||||
end
|
||||
end
|
||||
|
||||
def test_full_with_memory_view
|
||||
samples = JFKReader.new(AUDIO)
|
||||
@parakeet.full @params, samples
|
||||
|
||||
segments = @parakeet.each_segment.to_a
|
||||
assert_equal 1, segments.length
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
|
||||
end
|
||||
|
||||
def test_full_with_memroy_view_gc
|
||||
samples = JFKReader.new(AUDIO)
|
||||
@parakeet.full(@params, samples)
|
||||
GC.start
|
||||
require "fiddle"
|
||||
Fiddle::MemoryView.export samples do |view|
|
||||
assert_equal 176000, view.to_s.unpack("#{view.format}*").length
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def test_transcribe
|
||||
assert_nothing_raised do
|
||||
@parakeet.transcribe AUDIO, @params
|
||||
end
|
||||
end
|
||||
|
||||
def test_transcribe_with_pathname
|
||||
assert_nothing_raised do
|
||||
@parakeet.transcribe Pathname(AUDIO), @params
|
||||
end
|
||||
end
|
||||
|
||||
def test_transcribe_with_nothing
|
||||
assert_raise_message(/open/) do
|
||||
@parakeet.transcribe "nothing", @params
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetContextParams < TestBase
|
||||
def setup
|
||||
@params = Parakeet::Context::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Context::Params, @params
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
assert_true @params.use_gpu
|
||||
assert_instance_of Integer, @params.gpu_device
|
||||
end
|
||||
|
||||
def test_attribute_writer
|
||||
@params.use_gpu = false
|
||||
assert_false @params.use_gpu
|
||||
|
||||
@params.gpu_device = 2
|
||||
assert_equal 2, @params.gpu_device
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetModel < TestBase
|
||||
def test_model
|
||||
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
|
||||
assert_instance_of Parakeet::Model, parakeet.model
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
|
||||
model = parakeet.model
|
||||
|
||||
assert_equal 10, model.n_vocab
|
||||
assert_equal 3200, model.n_audio_ctx
|
||||
assert_equal 8, model.n_audio_state
|
||||
assert_equal 2, model.n_audio_head
|
||||
assert_equal 1, model.n_audio_layer
|
||||
assert_equal 16, model.n_mels
|
||||
assert_equal 0, model.ftype
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
require_relative "helper"
|
||||
require "etc"
|
||||
|
||||
class TestParakeetParams < TestBase
|
||||
PARAM_NAMES = [
|
||||
:n_threads,
|
||||
:offset_ms,
|
||||
:duration_ms,
|
||||
:no_context,
|
||||
:audio_ctx
|
||||
]
|
||||
|
||||
def setup
|
||||
@params = Parakeet::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
assert_instance_of Parakeet::Params, @params
|
||||
end
|
||||
|
||||
def test_n_threads
|
||||
assert_equal [4, Etc.nprocessors].min, @params.n_threads
|
||||
|
||||
@params.n_threads = 1
|
||||
assert_equal 1, @params.n_threads
|
||||
end
|
||||
|
||||
def test_offset_ms
|
||||
assert_equal 0, @params.offset_ms
|
||||
|
||||
@params.offset_ms = 10_000
|
||||
assert_equal 10_000, @params.offset_ms
|
||||
end
|
||||
|
||||
def test_duration_ms
|
||||
assert_equal 0, @params.duration_ms
|
||||
|
||||
@params.duration_ms = 60_000
|
||||
assert_equal 60_000, @params.duration_ms
|
||||
end
|
||||
|
||||
def test_no_context
|
||||
assert_equal true, @params.no_context
|
||||
|
||||
@params.no_context = false
|
||||
assert_equal false, @params.no_context
|
||||
end
|
||||
|
||||
def test_audio_ctx
|
||||
assert_equal 0, @params.audio_ctx
|
||||
|
||||
@params.audio_ctx = 1
|
||||
assert_equal 1, @params.audio_ctx
|
||||
end
|
||||
|
||||
def test_new_with_kw_args
|
||||
params = Parakeet::Params.new(n_threads: 1)
|
||||
assert_equal 1, params.n_threads
|
||||
assert_equal 0, params.offset_ms
|
||||
end
|
||||
|
||||
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
|
||||
def test_new_with_kw_args_default_values(param)
|
||||
default_value = @params.send(param)
|
||||
value = case [param, default_value]
|
||||
in [*, true | false]
|
||||
!default_value
|
||||
in [*, Integer]
|
||||
default_value + 1
|
||||
end
|
||||
params = Parakeet::Params.new(param => value)
|
||||
assert_equal value, params.send(param)
|
||||
|
||||
PARAM_NAMES.reject {|name| name == param}.each do |name|
|
||||
assert_equal @params.send(name), params.send(name)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetSegment < TestBase
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
@parakeet.transcribe AUDIO, Parakeet::Params.new
|
||||
end
|
||||
|
||||
def test_segment
|
||||
whole_text = ""
|
||||
@parakeet.each_segment do |segment|
|
||||
assert_instance_of Parakeet::Segment, segment
|
||||
assert_kind_of Integer, segment.start_time
|
||||
assert segment.end_time >= segment.start_time
|
||||
assert_kind_of String, segment.text
|
||||
whole_text << segment.text
|
||||
end
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text)
|
||||
end
|
||||
|
||||
def test_deconstruct_keys
|
||||
segment = @parakeet.each_segment.first
|
||||
expected = {
|
||||
start_time: segment.start_time,
|
||||
end_time: segment.end_time,
|
||||
text: segment.text
|
||||
}
|
||||
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text])
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_nil
|
||||
segment = @parakeet.each_segment.first
|
||||
expected = {
|
||||
start_time: segment.start_time,
|
||||
end_time: segment.end_time,
|
||||
text: segment.text
|
||||
}
|
||||
assert_equal expected, segment.deconstruct_keys(nil)
|
||||
end
|
||||
end
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
require_relative "helper"
|
||||
|
||||
class TestParakeetToken < TestBase
|
||||
ATTRS = %i[
|
||||
id
|
||||
duration_idx
|
||||
duration_value
|
||||
frame_index
|
||||
probability
|
||||
log_probability
|
||||
start_time
|
||||
end_time
|
||||
word_start?
|
||||
text
|
||||
]
|
||||
|
||||
def setup
|
||||
omit "Skip not to download large model" if ENV["CI"]
|
||||
|
||||
Whisper.instance_variable_set "@whisper", nil
|
||||
GC.start
|
||||
|
||||
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
|
||||
params = Parakeet::Params.new
|
||||
parakeet.transcribe AUDIO, params
|
||||
@segment = parakeet.each_segment.first
|
||||
end
|
||||
|
||||
def test_each_token
|
||||
i = 0
|
||||
@segment.each_token do |token|
|
||||
i += 1
|
||||
assert_instance_of Parakeet::Token, token
|
||||
end
|
||||
assert_equal 38, i
|
||||
end
|
||||
|
||||
def test_each_token_without_block
|
||||
assert_instance_of Enumerator, @segment.each_token
|
||||
end
|
||||
|
||||
def test_token
|
||||
token = @segment.each_token.first
|
||||
|
||||
assert_instance_of Parakeet::Token, token
|
||||
assert_instance_of Integer, token.id
|
||||
assert_instance_of Integer, token.duration_idx
|
||||
assert_instance_of Integer, token.duration_value
|
||||
assert_instance_of Integer, token.frame_index
|
||||
assert_instance_of Float, token.probability
|
||||
assert_instance_of Float, token.log_probability
|
||||
assert_instance_of Integer, token.start_time
|
||||
assert_instance_of Integer, token.end_time
|
||||
assert_instance_of String, token.text
|
||||
end
|
||||
|
||||
def test_text
|
||||
assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."],
|
||||
@segment.each_token.collect(&:text)
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_nil
|
||||
token = @segment.each_token.first
|
||||
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
|
||||
assert_equal expected, token.deconstruct_keys(nil)
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_with_keys
|
||||
token = @segment.each_token.first
|
||||
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
|
||||
assert_equal expected, token.deconstruct_keys(expected.keys)
|
||||
end
|
||||
end
|
||||
|
|
@ -9,7 +9,7 @@ class TestVADSegment < TestBase
|
|||
end
|
||||
|
||||
assert_raise do
|
||||
segments.end_time
|
||||
segment.end_time
|
||||
end
|
||||
|
||||
assert_raise do
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ class TestWhisper < TestBase
|
|||
}
|
||||
Whisper.log_set log_callback, user_data
|
||||
Whisper::Context.new("base.en")
|
||||
sleep 0.1 # wait for logs dequeued
|
||||
|
||||
assert logs.length > 30
|
||||
logs.each do |log|
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ Gem::Specification.new do |s|
|
|||
s.test_files = s.files.select {|file| file.start_with? "test/"}
|
||||
|
||||
s.extensions << 'ext/extconf.rb'
|
||||
s.required_ruby_version = '>= 3.1.0'
|
||||
s.required_ruby_version = '>= 3.3.0'
|
||||
|
||||
#### Documentation and testing.
|
||||
s.homepage = 'https://github.com/ggml-org/whisper.cpp'
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
# CMake package configuration template for the parakeet library.
#
# After configure-time substitution this file defines:
#   PARAKEET_VERSION / PARAKEET_BUILD_COMMIT / PARAKEET_BUILD_NUMBER / PARAKEET_SHARED_LIB
#   PARAKEET_INCLUDE_DIR / PARAKEET_LIB_DIR / PARAKEET_BIN_DIR
# and the imported target `parakeet`, which carries the include directory and the
# ggml link dependencies as usage requirements.
set(PARAKEET_VERSION      @WHISPER_INSTALL_VERSION@)
set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@)
set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@)
set(PARAKEET_SHARED_LIB   @BUILD_SHARED_LIBS@)

@PACKAGE_INIT@

set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@")
set_and_check(PARAKEET_LIB_DIR     "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@")
set_and_check(PARAKEET_BIN_DIR     "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@")

# Paths are quoted so install prefixes containing spaces or semicolons stay one argument.
find_package(ggml REQUIRED HINTS "${PARAKEET_LIB_DIR}/cmake")

find_library(parakeet_LIBRARY parakeet
    REQUIRED
    HINTS "${PARAKEET_LIB_DIR}"
    NO_CMAKE_FIND_ROOT_PATH
)

# find_package(parakeet) may run more than once in the same directory scope (for
# example via two dependencies that both require it); creating the imported target a
# second time would be a hard error, so only define it when it does not exist yet.
if (NOT TARGET parakeet)
    add_library(parakeet UNKNOWN IMPORTED)
    set_target_properties(parakeet
        PROPERTIES
            INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}"
            INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base"
            IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
            IMPORTED_LOCATION "${parakeet_LIBRARY}"
            INTERFACE_COMPILE_FEATURES cxx_std_11
            POSITION_INDEPENDENT_CODE ON)
endif()

check_required_components(parakeet)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
# pkg-config template for the parakeet library; @...@ placeholders are substituted
# by CMake at configure time.
prefix=@CMAKE_INSTALL_PREFIX@
exec_prefix=${prefix}
libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@
includedir=${prefix}/include

Name: parakeet
Description: Port of NVIDIA's Parakeet model in C/C++
Version: @PROJECT_VERSION@
# Dependents come before their dependencies so that static linking with a
# left-to-right linker resolves parakeet's references into ggml and ggml-base.
Libs: -L${libdir} -lparakeet -lggml -lggml-base
Cflags: -I${includedir}
|
||||
|
|
@ -107,6 +107,8 @@ else()
|
|||
add_subdirectory(server)
|
||||
add_subdirectory(quantize)
|
||||
add_subdirectory(vad-speech-segments)
|
||||
add_subdirectory(parakeet-cli)
|
||||
add_subdirectory(parakeet-quantize)
|
||||
if (WHISPER_SDL2)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
# parakeet-cli: command-line example executable built from parakeet-cli.cpp.
set(TARGET parakeet-cli)

add_executable(${TARGET}
    parakeet-cli.cpp
)

# Included after the target is created and while TARGET is set.
# NOTE(review): presumably applies the project's shared options to ${TARGET} - confirm.
include(DefaultTargetOptions)

target_link_libraries(${TARGET}
    PRIVATE
        common
        parakeet
        ${FFMPEG_LIBRARIES}
        ${CMAKE_THREAD_LIBS_INIT}
)

install(TARGETS ${TARGET} RUNTIME)
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
# whisper.cpp/examples/parakeet-cli
|
||||
|
||||
This is an example of using the [Parakeet] model in whisper.cpp.
|
||||
|
||||
### Download converted model
|
||||
```console
|
||||
$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models
|
||||
```
|
||||
|
||||
### Building
|
||||
```console
|
||||
$ cmake -B build -S .
|
||||
$ cmake --build build --target parakeet-cli -j 12
|
||||
```
|
||||
|
||||
### Usage
|
||||
```console
|
||||
$ ./build/bin/parakeet-cli --help
|
||||
|
||||
usage: ./build/bin/parakeet-cli [options] file0 file1 ...
|
||||
supported audio formats: flac, mp3, ogg, wav
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path
|
||||
-f, --file FILE [ ] input audio file
|
||||
-ng, --no-gpu [false ] disable GPU
|
||||
-dev N, --device N [0 ] GPU device to use
|
||||
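-ps, --print-segments [false ] print segment information
-otxt, --output-txt [false ] output result in a text file
-of, --output-file FILE [ ] output file path (without file extension)
-np, --no-prints [false ] do not print anything other than the results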
|
||||
```
|
||||
|
||||
### Example
|
||||
```console
|
||||
$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav
|
||||
Processing audio (176000 samples, 11.00 seconds)
|
||||
Processing audio: total_frames=1101, chunk_size=1101
|
||||
parakeet_decode: starting decode with n_frames=138
|
||||
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
```
|
||||
|
||||
To print segment information:
|
||||
```console
|
||||
$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments
|
||||
Processing audio (176000 samples, 11.00 seconds)
|
||||
Processing audio: total_frames=1101, chunk_size=1101
|
||||
parakeet_decode: starting decode with n_frames=138
|
||||
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
Segments (1):
|
||||
Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country."
|
||||
Tokens [38]:
|
||||
[ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And"
|
||||
[ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so"
|
||||
[ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false ","
|
||||
[ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my"
|
||||
[ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f"
|
||||
[ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell"
|
||||
[ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow"
|
||||
[ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer"
|
||||
[ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic"
|
||||
[ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans"
|
||||
[10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false ","
|
||||
[11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a"
|
||||
[12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk"
|
||||
[13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not"
|
||||
[14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what"
|
||||
[15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your"
|
||||
[16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co"
|
||||
[17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un"
|
||||
[18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr"
|
||||
[19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y"
|
||||
[20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can"
|
||||
[21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do"
|
||||
[22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for"
|
||||
[23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you"
|
||||
[24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false ","
|
||||
[25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a"
|
||||
[26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk"
|
||||
[27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what"
|
||||
[28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you"
|
||||
[29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can"
|
||||
[30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do"
|
||||
[31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for"
|
||||
[32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your"
|
||||
[33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co"
|
||||
[34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un"
|
||||
[35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr"
|
||||
[36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y"
|
||||
[37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "."
|
||||
```
|
||||
|
||||
### Model conversion
|
||||
Clone the original model from Hugging Face:
|
||||
```console
|
||||
$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
|
||||
```
|
||||
Convert the model:
|
||||
```console
|
||||
(venv) $ python models/convert-parakeet-to-ggml.py \
|
||||
--model <path to cloned model> \
|
||||
--out-dir models \
|
||||
--out-name ggml-parakeet-tdt-0.6b-v3-f16.bin
|
||||
```
|
||||
|
||||
[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
#include "parakeet.h"
#include "common-whisper.h"

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <string>
#include <thread>
#include <vector>
|
||||
|
||||
// command-line parameters
|
||||
struct parakeet_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
|
||||
bool use_gpu = true;
|
||||
int32_t gpu_device = 0;
|
||||
|
||||
bool print_segments = false;
|
||||
bool output_txt = false;
|
||||
bool no_prints = false;
|
||||
|
||||
std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin";
|
||||
std::string output_file = "";
|
||||
std::vector<std::string> fname_inp = {};
|
||||
};
|
||||
|
||||
static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params);
|
||||
|
||||
// Reports that option `arg` was given without its required value and terminates the
// process with exit status 1.
//
// The function never returns. It is declared as returning `char *` only so that a call
// can appear as the alternative branch of the ARGV_NEXT conditional expression, whose
// other branch is `argv[++i]`. [[noreturn]] makes that contract explicit to the compiler.
[[noreturn]] static char * requires_value_error(const std::string & arg) {
    fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
    exit(1);
}
|
||||
|
||||
static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) {
|
||||
if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) {
|
||||
params.gpu_device = std::stoi(env_device);
|
||||
}
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-"){
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg[0] != '-') {
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
parakeet_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; }
|
||||
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
parakeet_print_usage(argc, argv, params);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]);
|
||||
fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", "");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device);
|
||||
fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", "");
|
||||
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
|
||||
bool * is_first = (bool *) user_data;
|
||||
|
||||
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
|
||||
char text_buf[256];
|
||||
parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf));
|
||||
printf("%s", text_buf);
|
||||
fflush(stdout);
|
||||
|
||||
*is_first = false;
|
||||
}
|
||||
|
||||
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
parakeet_params params;
|
||||
|
||||
if (parakeet_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.no_prints) {
|
||||
parakeet_log_set(cb_log_disable, NULL);
|
||||
}
|
||||
|
||||
if (params.fname_inp.empty()) {
|
||||
fprintf(stderr, "error: no input files specified\n");
|
||||
parakeet_print_usage(argc, argv, params);
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct parakeet_context_params ctx_params = parakeet_context_default_params();
|
||||
ctx_params.use_gpu = params.use_gpu;
|
||||
ctx_params.gpu_device = params.gpu_device;
|
||||
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
|
||||
}
|
||||
|
||||
|
||||
struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
|
||||
if (pctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "Successfully loaded Parakeet model\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
|
||||
}
|
||||
|
||||
// Process each input file
|
||||
for (const auto & fname : params.fname_inp) {
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "\nProcessing file: %s\n", fname.c_str());
|
||||
}
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) {
|
||||
fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pcmf32.empty()) {
|
||||
fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_first = true;
|
||||
struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
full_params.n_threads = params.n_threads;
|
||||
full_params.new_token_callback = token_callback;
|
||||
full_params.new_token_callback_user_data = &is_first;
|
||||
|
||||
const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH);
|
||||
int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size());
|
||||
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
if (params.output_txt) {
|
||||
const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt";
|
||||
|
||||
std::ofstream fout(fname_out);
|
||||
if (fout.is_open()) {
|
||||
const int n_segments = parakeet_full_n_segments(pctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = parakeet_full_get_segment_text(pctx, i);
|
||||
fout << text << "\n";
|
||||
}
|
||||
fout.close();
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "Output written to: %s\n", fname_out.c_str());
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.no_prints) {
|
||||
parakeet_print_timings(pctx);
|
||||
}
|
||||
|
||||
if (params.print_segments) {
|
||||
const int n_segments = parakeet_full_n_segments(pctx);
|
||||
fprintf(stderr, "\nSegments (%d):\n", n_segments);
|
||||
|
||||
for (int i = 0; i < n_segments; i++) {
|
||||
const char * text = parakeet_full_get_segment_text(pctx, i);
|
||||
const int64_t t0 = parakeet_full_get_segment_t0(pctx, i);
|
||||
const int64_t t1 = parakeet_full_get_segment_t1(pctx, i);
|
||||
const int n_tokens = parakeet_full_n_tokens(pctx, i);
|
||||
|
||||
fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
|
||||
fprintf(stderr, "Tokens [%d]:\n", n_tokens);
|
||||
|
||||
for (int j = 0; j < n_tokens; j++) {
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j);
|
||||
const char * token_str = parakeet_token_to_str(pctx, token_data.id);
|
||||
|
||||
fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n",
|
||||
j,
|
||||
token_data.id,
|
||||
token_data.frame_index,
|
||||
token_data.duration_idx,
|
||||
token_data.duration_value,
|
||||
token_data.p,
|
||||
token_data.plog,
|
||||
(long long)token_data.t0,
|
||||
(long long)token_data.t1,
|
||||
token_data.is_word_start ? "true": "false",
|
||||
token_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parakeet_free(pctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Build and install rules for the parakeet-quantize command-line tool.
set(TARGET parakeet-quantize)

add_executable(${TARGET}
    parakeet-quantize.cpp
)

# NOTE(review): DefaultTargetOptions is not visible from this file; it presumably
# configures ${TARGET}, so keep this include after add_executable() - confirm.
include(DefaultTargetOptions)

target_link_libraries(${TARGET}
    PRIVATE
        common
        parakeet
        ${CMAKE_THREAD_LIBS_INIT}
)

install(TARGETS ${TARGET} RUNTIME)
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Hyperparameters stored in the header of a Parakeet ggml model file.
//
// parakeet_model_quantize() reads and writes each member as an individual
// int32 value, in the order the members are declared below. Every member
// defaults to 0 until it is filled in from the file.
struct parakeet_hparams {
    int32_t n_vocab                = 0;
    int32_t n_audio_ctx            = 0;
    int32_t n_audio_state          = 0;
    int32_t n_audio_head           = 0;
    int32_t n_audio_layer          = 0;
    int32_t n_mels                 = 0;
    int32_t ftype                  = 0; // replaced by the destination ftype when the quantized file is written
    int32_t n_fft                  = 0;
    int32_t subsampling_factor     = 0;
    int32_t n_subsampling_channels = 0;
    int32_t n_conv_kernel          = 0;
    int32_t n_pred_dim             = 0; // compared against the quantization block size to choose tensors to skip
    int32_t n_pred_layers          = 0;
    int32_t n_tdt_durations        = 0; // number of uint32 entries in the TDT durations table
    int32_t n_max_tokens           = 0;
};
|
||||
|
||||
static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
|
||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||
|
||||
auto finp = std::ifstream(fname_inp, std::ios::binary);
|
||||
if (!finp) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto fout = std::ofstream(fname_out, std::ios::binary);
|
||||
if (!fout) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// magic
|
||||
{
|
||||
uint32_t magic;
|
||||
finp.read((char *) &magic, sizeof(magic));
|
||||
if (magic != GGML_FILE_MAGIC) {
|
||||
fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__);
|
||||
return false;
|
||||
}
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
}
|
||||
|
||||
// hparams
|
||||
parakeet_hparams hparams;
|
||||
{
|
||||
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
|
||||
finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft));
|
||||
finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor));
|
||||
finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels));
|
||||
finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel));
|
||||
finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim));
|
||||
finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers));
|
||||
finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations));
|
||||
finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens));
|
||||
|
||||
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
|
||||
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
|
||||
|
||||
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
||||
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
||||
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
||||
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
|
||||
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
|
||||
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
|
||||
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
|
||||
|
||||
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
fout.write((char *) &ftype_dst, sizeof(ftype_dst));
|
||||
fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft));
|
||||
fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor));
|
||||
fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels));
|
||||
fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel));
|
||||
fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim));
|
||||
fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers));
|
||||
fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations));
|
||||
fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens));
|
||||
}
|
||||
|
||||
// mel filterbank
|
||||
{
|
||||
int32_t n_mel, n_fb;
|
||||
finp.read((char *) &n_mel, sizeof(n_mel));
|
||||
fout.write((char *) &n_mel, sizeof(n_mel));
|
||||
finp.read((char *) &n_fb, sizeof(n_fb));
|
||||
fout.write((char *) &n_fb, sizeof(n_fb));
|
||||
|
||||
const size_t n = (size_t) n_mel * n_fb;
|
||||
std::vector<float> buf(n);
|
||||
finp.read((char *) buf.data(), n * sizeof(float));
|
||||
fout.write((char *) buf.data(), n * sizeof(float));
|
||||
}
|
||||
|
||||
// window function
|
||||
{
|
||||
int32_t n_window;
|
||||
finp.read((char *) &n_window, sizeof(n_window));
|
||||
fout.write((char *) &n_window, sizeof(n_window));
|
||||
|
||||
std::vector<float> buf(n_window);
|
||||
finp.read((char *) buf.data(), n_window * sizeof(float));
|
||||
fout.write((char *) buf.data(), n_window * sizeof(float));
|
||||
}
|
||||
|
||||
// TDT durations
|
||||
{
|
||||
std::vector<uint32_t> buf(hparams.n_tdt_durations);
|
||||
finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t));
|
||||
fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
// vocab
|
||||
{
|
||||
int32_t n_tokens;
|
||||
finp.read((char *) &n_tokens, sizeof(n_tokens));
|
||||
fout.write((char *) &n_tokens, sizeof(n_tokens));
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
int32_t len;
|
||||
finp.read((char *) &len, sizeof(len));
|
||||
fout.write((char *) &len, sizeof(len));
|
||||
|
||||
std::string token(len, '\0');
|
||||
finp.read(&token[0], len);
|
||||
fout.write(&token[0], len);
|
||||
}
|
||||
}
|
||||
|
||||
// tensors — quantize 2D weights skipping tensors that must stay F32:
|
||||
// ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights.
|
||||
// pos_bias_u / pos_bias_v are declared F32 in the loader.
|
||||
const std::vector<std::string> to_quant = { ".*" };
|
||||
std::vector<std::string> to_skip = {
|
||||
// CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights)
|
||||
"encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight",
|
||||
// Declared F32 in loader (pos_bias tensors)
|
||||
"encoder\\.layers\\..+\\.self_attn\\.pos_bias_u",
|
||||
"encoder\\.layers\\..+\\.self_attn\\.pos_bias_v",
|
||||
};
|
||||
|
||||
// Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant
|
||||
// types (block size 256) cannot quantize 640 evenly, so keep them F32. For
|
||||
// other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be
|
||||
// quantized normally. The loader mirrors this logic at load time.
|
||||
{
|
||||
const ggml_type qtype = ggml_ftype_to_ggml_type(ftype);
|
||||
const int32_t blck = ggml_blck_size(qtype);
|
||||
if (blck > 1 && hparams.n_pred_dim % blck != 0) {
|
||||
to_skip.push_back("decoder\\.prediction\\.embed\\.weight");
|
||||
to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*");
|
||||
to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*");
|
||||
to_skip.push_back("joint\\.pred\\.weight");
|
||||
to_skip.push_back("joint\\.joint_net\\.2\\.weight");
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) {
|
||||
fprintf(stderr, "%s: failed to quantize tensors\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
finp.close();
|
||||
fout.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (argc != 4) {
|
||||
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// initialise F16 lookup tables
|
||||
{
|
||||
struct ggml_init_params params = { 0, NULL, false };
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
const std::string fname_inp = argv[1];
|
||||
const std::string fname_out = argv[2];
|
||||
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
|
||||
|
||||
if (ftype == GGML_FTYPE_UNKNOWN) {
|
||||
fprintf(stderr, "%s: invalid quantization type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) {
|
||||
fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f);
|
||||
printf("%s: output model = %s\n", argv[0], fname_out.c_str());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,342 @@
|
|||
#ifndef PARAKEET_H
|
||||
#define PARAKEET_H
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __GNUC__
|
||||
# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
|
||||
#elif defined(_MSC_VER)
|
||||
# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
|
||||
#else
|
||||
# define PARAKEET_DEPRECATED(func, hint) func
|
||||
#endif
|
||||
|
||||
#ifdef PARAKEET_SHARED
|
||||
# ifdef _WIN32
|
||||
# ifdef PARAKEET_BUILD
|
||||
# define PARAKEET_API __declspec(dllexport)
|
||||
# else
|
||||
# define PARAKEET_API __declspec(dllimport)
|
||||
# endif
|
||||
# else
|
||||
# define PARAKEET_API __attribute__ ((visibility ("default")))
|
||||
# endif
|
||||
#else
|
||||
# define PARAKEET_API
|
||||
#endif
|
||||
|
||||
#define PARAKEET_SAMPLE_RATE 16000
|
||||
#define PARAKEET_HOP_LENGTH 160
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct parakeet_context;
|
||||
struct parakeet_state;
|
||||
struct parakeet_full_params;
|
||||
|
||||
typedef int32_t parakeet_pos;
|
||||
typedef int32_t parakeet_token;
|
||||
typedef int32_t parakeet_seq_id;
|
||||
|
||||
// Options supplied to the parakeet_init_*_with_params() functions.
struct parakeet_context_params {
    bool use_gpu;
    int  gpu_device; // CUDA device
};
|
||||
|
||||
typedef struct parakeet_token_data {
|
||||
parakeet_token id; // the BPE subword ID (0-8191)
|
||||
|
||||
int duration_idx; // index into the models durations array
|
||||
int duration_value; // actual duration value
|
||||
int frame_index;
|
||||
|
||||
float p;
|
||||
float plog;
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
bool is_word_start;
|
||||
} parakeet_token_data;
|
||||
|
||||
// Callback table accepted by parakeet_init_with_params() and
// parakeet_init_with_params_no_state() as the source of the model data.
// NOTE(review): `context` is presumably the value handed to each callback as `ctx` - confirm.
typedef struct parakeet_model_loader {
    void * context;

    size_t (*read )(void * ctx, void * output, size_t read_size);
    bool   (*eof  )(void * ctx);
    void   (*close)(void * ctx);
} parakeet_model_loader;
|
||||
|
||||
PARAKEET_API const char * parakeet_version(void);
|
||||
|
||||
// Various functions for loading a ggml parakeet model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params);
|
||||
|
||||
// These are the same as the above, but the internal state of the context is not allocated automatically
|
||||
// It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523)
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params);
|
||||
|
||||
PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx);
|
||||
|
||||
// Frees all allocated memory
|
||||
PARAKEET_API void parakeet_free (struct parakeet_context * ctx);
|
||||
PARAKEET_API void parakeet_free_state(struct parakeet_state * state);
|
||||
PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params);
|
||||
PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params);
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the default state of the provided parakeet context.
|
||||
// Returns 0 on success
|
||||
PARAKEET_API int parakeet_pcm_to_mel(
|
||||
struct parakeet_context * ctx,
|
||||
const float * samples,
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
PARAKEET_API int parakeet_pcm_to_mel_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
const float * samples,
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context.
|
||||
// Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 128
|
||||
// Returns 0 on success
|
||||
PARAKEET_API int parakeet_set_mel(
|
||||
struct parakeet_context * ctx,
|
||||
const float * data,
|
||||
int n_len,
|
||||
int n_mel);
|
||||
|
||||
PARAKEET_API int parakeet_set_mel_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
const float * data,
|
||||
int n_len,
|
||||
int n_mel);
|
||||
|
||||
// Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context.
|
||||
// Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
// Returns 0 on success
|
||||
PARAKEET_API int parakeet_encode(
|
||||
struct parakeet_context * ctx,
|
||||
int offset,
|
||||
int n_threads);
|
||||
|
||||
PARAKEET_API int parakeet_encode_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
int offset,
|
||||
int n_threads);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||
// TODO: not sure if correct
|
||||
PARAKEET_API int parakeet_tokenize(
|
||||
struct parakeet_context * ctx,
|
||||
const char * text,
|
||||
parakeet_token * tokens,
|
||||
int n_max_tokens);
|
||||
|
||||
// Return the number of tokens in the provided text
|
||||
// Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0)
|
||||
int parakeet_token_count(struct parakeet_context * ctx, const char * text);
|
||||
|
||||
PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length
|
||||
PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length
|
||||
PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx);
|
||||
|
||||
PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx);
|
||||
|
||||
// Token logits obtained from the last call to parakeet_full/parakeet_chunk
|
||||
// The logits for the last token are stored in the last row
|
||||
// Rows: n_tokens
|
||||
// Cols: n_vocab
|
||||
PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx);
|
||||
PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token);
|
||||
|
||||
PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len);
|
||||
|
||||
// Special tokens
|
||||
PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx);
|
||||
PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx);
|
||||
PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx);
|
||||
|
||||
// Performance information from the default state.
// Obtained through parakeet_get_timings().
// NOTE(review): the _ms suffix suggests milliseconds - confirm in parakeet.cpp.
struct parakeet_timings {
    float sample_ms;
    float encode_ms;
    float decode_ms;
};
|
||||
PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx);
|
||||
PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx);
|
||||
PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx);
|
||||
|
||||
// Print system information
|
||||
PARAKEET_API const char * parakeet_print_system_info(void);
|
||||
|
||||
// Available sampling strategies
// Passed to parakeet_full_default_params() and stored in parakeet_full_params::strategy.
enum parakeet_sampling_strategy {
    PARAKEET_SAMPLING_GREEDY, // currently the only strategy
};
|
||||
|
||||
// Token callback.
|
||||
// Called for each new predicted token.
|
||||
// Use the parakeet_full_...() functions to obtain the text segments
|
||||
typedef void (*parakeet_new_token_callback)(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
const parakeet_token_data * token_data,
|
||||
void * user_data);
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the parakeet_full_...() functions to obtain the text segments
|
||||
typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data);
|
||||
|
||||
// Progress callback
|
||||
typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data);
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data);
|
||||
|
||||
// Parameters for the parakeet_full() function
|
||||
// If you change the order or add new parameters, make sure to update the default values in parakeet.cpp:
|
||||
// parakeet_full_default_params()
|
||||
struct parakeet_full_params {
|
||||
enum parakeet_sampling_strategy strategy;
|
||||
|
||||
int n_threads;
|
||||
int offset_ms; // start offset in ms
|
||||
int duration_ms; // audio duration to process in ms
|
||||
|
||||
bool no_context; // do not use past transcription (if any) as context
|
||||
|
||||
int audio_ctx; // overwrite the audio context size (0 = use default)
|
||||
|
||||
// called for every newly generated text segment
|
||||
parakeet_new_segment_callback new_segment_callback;
|
||||
void * new_segment_callback_user_data;
|
||||
|
||||
// called for every newly generated token
|
||||
parakeet_new_token_callback new_token_callback;
|
||||
void * new_token_callback_user_data;
|
||||
|
||||
// called on each progress update
|
||||
parakeet_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
|
||||
// called each time before the encoder starts
|
||||
parakeet_encoder_begin_callback encoder_begin_callback;
|
||||
void * encoder_begin_callback_user_data;
|
||||
|
||||
// called each time before ggml computation starts
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_user_data;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params()
|
||||
PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void);
|
||||
PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void);
|
||||
|
||||
PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy);
|
||||
PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy);
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Not thread safe for same context
|
||||
PARAKEET_API int parakeet_full(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_full_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
PARAKEET_API int parakeet_full_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
struct parakeet_full_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
// Process a single chunk of audio data that fits within the model's audio context window.
|
||||
// This is more efficient than parakeet_full() for short audio clips.
|
||||
PARAKEET_API int parakeet_chunk(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
struct parakeet_full_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
// Number of generated text segments
|
||||
PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state);
|
||||
|
||||
// Get the start and end time of the specified segment
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
// Get the text of the specified segment
|
||||
PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
// Get number of tokens in the specified segment
|
||||
PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
// Get the token text of the specified token in the specified segment
|
||||
PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Get the token id of the specified token in the specified segment
|
||||
PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Get token data for the specified token in the specified segment
|
||||
PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Get the probability of the specified token in the specified segment
|
||||
PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Control logging output; default behavior is to print to stderr
|
||||
|
||||
PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,337 @@
|
|||
#!/usr/bin/env python3
|
||||
# Convert Parakeet TDT model from NeMo format to ggml format
|
||||
#
|
||||
# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32]
|
||||
#
|
||||
# The NeMo file is a tar archive containing:
|
||||
# - model_weights.ckpt (PyTorch checkpoint)
|
||||
# - model_config.yaml (model configuration)
|
||||
# - tokenizer files
|
||||
#
|
||||
# This script extracts the NeMo archive, loads the model weights and configuration,
|
||||
# and saves them in ggml format compatible with whisper.cpp.
|
||||
#
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import yaml
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
def hz_to_mel(freq):
    """Convert a frequency in Hz to mels using 2595 * log10(1 + freq / 700).

    `freq` is passed through numpy arithmetic, so a scalar or a numpy array works.
    """
    ratio = 1.0 + freq / 700.0
    return 2595.0 * np.log10(ratio)
|
||||
|
||||
def mel_to_hz(mel):
    """Convert a mel value back to Hz using 700 * (10 ** (mel / 2595) - 1).

    This is the inverse of hz_to_mel().
    """
    exponent = mel / 2595.0
    return 700.0 * (10.0**exponent - 1.0)
|
||||
|
||||
def extract_nemo_archive(nemo_path, extract_dir):
    """Unpack the NeMo tar archive `nemo_path` into the directory `extract_dir`.

    The archive is supplied by the user, so its members are not trusted:
    extraction refuses entries that would be written outside `extract_dir`.

    Raises:
        tarfile.TarError or ValueError: if a member would escape `extract_dir`.
        OSError: if the archive cannot be opened or a file cannot be written.
    """
    print(f"Extracting {nemo_path} to {extract_dir}")
    with tarfile.open(nemo_path, 'r') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Interpreters with extraction filters: 'data' rejects path traversal,
            # links pointing outside the destination and special files.
            tar.extractall(path=extract_dir, filter='data')
        else:
            # Older interpreters: reject members whose resolved path leaves extract_dir.
            # NOTE(review): this fallback only checks member names, not link targets.
            root = os.path.realpath(extract_dir)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise ValueError(f"Refusing to extract '{member.name}' outside of {extract_dir}")
            tar.extractall(path=extract_dir)
    print("Extraction complete")
|
||||
|
||||
def load_model_config(config_path):
    """Parse the YAML file at `config_path` (read as UTF-8) and return the loaded object."""
    with open(config_path, 'r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)
|
||||
|
||||
def load_tokenizer(extract_dir, config):
    """Load the tokenizer vocabulary found in `extract_dir`.

    The directory must contain a file ending in '_tokenizer.model' and one ending
    in 'tokenizer.vocab'; if several match, the last one listed is used. Only the
    vocab file is read: the first tab-separated field of line N (0-based) becomes
    the token with id N.

    `config` is accepted for interface compatibility and is not used.

    Returns:
        dict mapping each token's UTF-8 bytes to its line index.

    Raises:
        FileNotFoundError: if either tokenizer file is missing.
    """
    entries = os.listdir(extract_dir)
    model_files = [e for e in entries if e.endswith('_tokenizer.model')]
    vocab_files = [e for e in entries
                   if not e.endswith('_tokenizer.model') and e.endswith('tokenizer.vocab')]

    if not model_files:
        raise FileNotFoundError("Tokenizer model file not found")

    if not vocab_files:
        raise FileNotFoundError("Tokenizer vocab file not found")

    vocab_path = os.path.join(extract_dir, vocab_files[-1])

    with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
        # str.split() always yields at least one element, so every line contributes a token
        tokens = {
            line.strip().split('\t')[0].encode('utf-8'): line_no
            for line_no, line in enumerate(vocab_file)
        }

    print(f"Loaded {len(tokens)} tokens from {os.path.basename(vocab_path)}")

    if len(tokens) != 8192:
        print(f"WARNING: Expected 8192 tokens, got {len(tokens)}")

    return tokens
|
||||
|
||||
def write_tensor(fout, name, data, use_f16=True, force_f32=False):
    """Serialize one tensor to `fout`.

    Record layout, using native-endian 32-bit ints:
        n_dims, len(name in UTF-8), ftype   (ftype: 0 = float32, 1 = float16)
        one int per dimension, innermost (last numpy axis) first
        the UTF-8 name bytes
        the raw tensor data

    A tensor is stored as float16 only when `use_f16` is set, `force_f32` is not,
    and none of the keep-as-float32 rules below match. 1-D 'pre_encode.conv' bias
    tensors are reshaped to (1, C, 1, 1) before being written.

    Args:
        fout: binary file object; must be a real file because numpy's tofile() is used.
        name: tensor name.
        data: numpy array with the tensor values.
        use_f16: allow float16 storage for eligible tensors.
        force_f32: always store as float32.
    """
    is_pre_encode_conv = 'pre_encode.conv' in name

    if is_pre_encode_conv and 'bias' in name and len(data.shape) == 1:
        data = data.reshape(1, -1, 1, 1)
        print(f" Reshaped conv bias {name} to {data.shape}")

    n_dims = len(data.shape)

    # Tensors matching any of these rules stay float32 even when use_f16 is requested.
    must_stay_f32 = (
        n_dims < 2
        or 'bias' in name
        or 'norm' in name
        or (is_pre_encode_conv and n_dims == 4)
        or 'depthwise_conv.weight' in name
    )
    store_f16 = use_f16 and not force_f32 and not must_stay_f32

    ftype = 1 if store_f16 else 0
    data = data.astype(np.float16 if store_f16 else np.float32)

    dims_reversed = list(reversed(data.shape))
    print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}")

    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
    for dim in dims_reversed:
        fout.write(struct.pack("i", dim))
    fout.write(name_bytes)

    data.tofile(fout)
|
||||
|
||||
def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None):
    """Convert a Parakeet ``.nemo`` archive into a single ggml model file.

    The archive is extracted into a temporary directory, its
    ``model_config.yaml`` and ``model_weights.ckpt`` are loaded, and the
    output file is written in this order: magic number, hyperparameters,
    mel filterbank, window function, TDT durations, vocabulary, and finally
    one record per tensor (see ``write_tensor``).

    Args:
        nemo_path: Path to the ``.nemo`` model archive.
        output_dir: Directory for the output file; created if missing.
        use_f16: Store eligible tensors as float16 when true, float32 otherwise.
        out_name: Output file name. Defaults to ``ggml-model.bin`` (f16) or
            ``ggml-model-f32.bin`` (f32).

    Raises:
        ValueError: If the mel filterbank or window tensor is missing or has
            an unexpected rank, or the TDT duration count does not match the
            configuration.
        KeyError: If a prediction-LSTM ``bias_hh`` tensor has no matching
            ``bias_ih`` tensor, or an expected configuration key is absent.
    """
    nemo_path = Path(nemo_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create temporary directory for extraction
    with tempfile.TemporaryDirectory() as temp_dir:
        extract_nemo_archive(nemo_path, temp_dir)

        config_path = os.path.join(temp_dir, 'model_config.yaml')
        config = load_model_config(config_path)

        print("Model configuration:")
        print(f" Sample rate: {config['sample_rate']}")
        print(f" Encoder layers: {config['encoder']['n_layers']}")
        print(f" Encoder d_model: {config['encoder']['d_model']}")
        print(f" Mel features: {config['preprocessor']['features']}")

        weights_path = os.path.join(temp_dir, 'model_weights.ckpt')
        print(f"\nLoading model weights from {weights_path}")
        # NOTE(security): torch.load unpickles the checkpoint, so only convert
        # archives obtained from a trusted source.
        checkpoint = torch.load(weights_path, map_location='cpu')

        # Extract state dict
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        print(f"Loaded {len(state_dict)} tensors")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokens = load_tokenizer(temp_dir, config)
        print(f"Loaded {len(tokens)} tokens")

        # Prepare hyperparameters for the Parakeet ggml format.
        hparams = {
            'n_audio_ctx': 5000,
            'n_audio_state': config['encoder']['d_model'],
            'n_audio_head': config['encoder']['n_heads'],
            'n_audio_layer': config['encoder']['n_layers'],
            'n_mels': config['preprocessor']['features'],
            'n_fft': config['preprocessor']['n_fft'],
            'subsampling_factor': config['encoder']['subsampling_factor'],
            'n_subsampling_channels': config['encoder']['subsampling_conv_channels'],
            'n_conv_kernel': config['encoder']['conv_kernel_size'],

            'n_pred_dim': config['decoder']['prednet']['pred_hidden'],
            'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'],
            'n_vocab': config['decoder']['vocab_size'],
            'n_tdt_durations': config['model_defaults']['num_tdt_durations'],
            'n_max_tokens': config['decoding']['greedy']['max_symbols'],
        }

        print("\nGGML hyperparameters:")
        for key, value in hparams.items():
            print(f" {key}: {value}")

        # Create output file
        if out_name:
            fname_out = output_dir / out_name
        else:
            fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin")
        print(f"\nWriting to {fname_out}")

        with open(fname_out, 'wb') as fout:
            # Write magic number
            fout.write(struct.pack("i", 0x67676d6c))  # 'ggml' in hex

            # Write hyperparameters. The order below is the on-disk order and
            # must not be changed.
            header_fields = (
                hparams['n_vocab'],
                hparams['n_audio_ctx'],
                hparams['n_audio_state'],
                hparams['n_audio_head'],
                hparams['n_audio_layer'],
                hparams['n_mels'],
                1 if use_f16 else 0,  # ftype
                hparams['n_fft'],
                hparams['subsampling_factor'],
                hparams['n_subsampling_channels'],
                hparams['n_conv_kernel'],
                hparams['n_pred_dim'],
                hparams['n_pred_layers'],
                hparams['n_tdt_durations'],
                hparams['n_max_tokens'],
            )
            for value in header_fields:
                fout.write(struct.pack("i", value))

            # Extract mel filterbank from model (first matching key wins)
            fb_key = None
            for key in state_dict.keys():
                if 'featurizer.fb' in key or 'filterbank' in key.lower():
                    fb_key = key
                    break

            if not fb_key:
                print("\nERROR: Mel filterbank not found in model!")
                print("Expected tensor with 'featurizer.fb' or 'filterbank' in name")
                print("\nAvailable preprocessor tensors:")
                for key in sorted(state_dict.keys()):
                    if 'preprocessor' in key or 'featurizer' in key:
                        print(f" {key}: {state_dict[key].shape}")
                raise ValueError("Mel filterbank tensor not found in model")

            print(f"\nUsing model's mel filterbank from: {fb_key}")
            mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32)
            print(f" Filterbank shape: {mel_filters.shape}")
            print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}")
            print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}")
            print(f" First row sum: {mel_filters[0].sum():.6f}")

            if len(mel_filters.shape) != 2:
                raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}")

            n_mels, n_freqs = mel_filters.shape
            fout.write(struct.pack("i", n_mels))  # n_mel
            fout.write(struct.pack("i", n_freqs))  # n_fb (frequency bins)

            # Write mel filterbank: row-major float32 values in one bulk
            # write (tobytes() emits C order), instead of one struct.pack
            # call per element.
            fout.write(mel_filters.tobytes())

            # Extract window function from model (first matching key wins)
            window_key = None
            for key in state_dict.keys():
                if 'featurizer.window' in key or ('preproc' in key and 'window' in key):
                    window_key = key
                    break

            if not window_key:
                print("\nERROR: Window function not found in model!")
                print("Expected tensor with 'featurizer.window' in name")
                raise ValueError("Window function tensor not found in model")

            print(f"\nUsing model's window function from: {window_key}")
            window = state_dict[window_key].squeeze().numpy().astype(np.float32)
            print(f" Window shape: {window.shape}")
            print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}")
            print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}")
            print(f" Window sum: {window.sum():.6f}")

            if len(window.shape) != 1:
                raise ValueError(f"Expected 1D window, got shape {window.shape}")

            n_window = window.shape[0]
            fout.write(struct.pack("i", n_window))

            # Write window function as consecutive float32 values
            fout.write(window.tobytes())

            # Write TDT durations
            tdt_durations = config['model_defaults']['tdt_durations']
            if len(tdt_durations) != hparams['n_tdt_durations']:
                raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}")

            for duration in tdt_durations:
                fout.write(struct.pack("I", duration))

            # Write vocabulary: token count, then length-prefixed token bytes
            # ordered by token index.
            fout.write(struct.pack("i", len(tokens)))
            for token_bytes, _idx in sorted(tokens.items(), key=lambda x: x[1]):
                fout.write(struct.pack("i", len(token_bytes)))
                fout.write(token_bytes)

            # Pre-collect prediction LSTM input-hidden biases so they can be
            # folded into the hidden-hidden bias during the main write loop.
            lstm_prefix = 'decoder.prediction.dec_rnn.lstm'
            pred_bias_ih = {}
            for key, t in state_dict.items():
                if f'{lstm_prefix}.bias_ih_l' in key:
                    layer_idx = int(key.rsplit('bias_ih_l', 1)[1])
                    pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32)

            print("\nConverting model weights...")
            for name, tensor in state_dict.items():
                # Skip the filterbank and window - already written in preprocessing section
                if name == fb_key:
                    continue
                if name == window_key:
                    continue

                # bias_ih is folded into bias_hh below; skip writing it separately
                if f'{lstm_prefix}.bias_ih_l' in name:
                    continue

                # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions
                if 'conv' in name and 'weight' in name and len(tensor.shape) == 4:
                    data = tensor.numpy()
                else:
                    data = tensor.squeeze().numpy()

                # For prediction LSTM weights/biases:
                # Fold bias_ih into bias_hh (bias_ih already skipped above).
                # Reorder gates (input, forget, cell, output) from PyTorch layout
                # [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs
                # (i, f, o) are contiguous.
                if name.startswith(f'{lstm_prefix}.'):
                    if f'{lstm_prefix}.bias_hh_l' in name:
                        layer_idx = int(name.rsplit('bias_hh_l', 1)[1])
                        if layer_idx not in pred_bias_ih:
                            # Same exception type as the bare dict lookup
                            # would raise, but with a useful message.
                            raise KeyError(
                                f"No bias_ih tensor found for prediction LSTM layer {layer_idx} "
                                f"(needed to fold into {name})"
                            )
                        data = data.astype(np.float32) + pred_bias_ih[layer_idx]
                        name = name.replace('bias_hh_l', 'bias_h_l')
                    h = data.shape[0] // 4
                    data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)

                write_tensor(fout, name, data, use_f16=use_f16)

        # The output file is closed here, so the reported size is final.
        print("\nConversion complete!")
        print(f"Output file: {fname_out}")
        print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB")
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: parse the options, check that the input
    # archive exists, then run the conversion.
    parser = argparse.ArgumentParser(
        description='Convert Parakeet TDT model from NeMo format to ggml format'
    )
    parser.add_argument(
        '--model',
        required=True,
        type=str,
        help='Path to Parakeet .nemo model file',
    )
    parser.add_argument(
        '--out-dir',
        required=True,
        type=str,
        help='Directory to write ggml model file',
    )
    parser.add_argument(
        '--use-f32',
        action='store_true',
        default=False,
        help='Use f32 instead of f16 (default: f16)',
    )
    parser.add_argument(
        '--out-name',
        default=None,
        type=str,
        help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)',
    )

    args = parser.parse_args()

    model_exists = os.path.exists(args.model)
    if not model_exists:
        print(f"Error: {args.model} not found")
        sys.exit(1)

    # f16 is the default; --use-f32 switches it off.
    use_f16 = not args.use_f32
    convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name)
|
||||
Binary file not shown.
|
|
@ -0,0 +1,182 @@
|
|||
#!/usr/bin/env python3
|
||||
import struct
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
def write_tensor(fout, name, data):
    """Append one float32 tensor record to the open binary file ``fout``.

    Record layout (native byte order): int32 number of dimensions, int32
    byte length of the UTF-8 name, int32 ftype (always 0), one int32 per
    dimension in reversed order, the name bytes, then the raw float32 data.

    Args:
        fout: Binary file object opened for writing (must support
            ``ndarray.tofile``).
        name: Tensor name written into the record.
        data: NumPy array; converted to float32 before writing.
    """
    rank = len(data.shape)
    payload = data.astype(np.float32)
    encoded_name = name.encode('utf-8')

    # ftype 0 == GGML_TYPE_F32
    fout.write(struct.pack("iii", rank, len(encoded_name), 0))
    for extent in reversed(payload.shape):
        fout.write(struct.pack("i", extent))
    fout.write(encoded_name)
    payload.tofile(fout)
|
||||
|
||||
def generate(output_path):
    """Write a small Parakeet TDT model file filled with random weights.

    The file starts with a magic number and the hyperparameters, followed by
    a random mel filterbank, a random window, the TDT durations, a synthetic
    vocabulary, and one ``write_tensor`` record per model tensor. All random
    values come from a generator seeded with 42, so the order of the
    ``f32(...)`` calls below determines the file contents.

    Args:
        output_path: Path of the file to create (overwritten if it exists).
    """
    rng = np.random.default_rng(42)

    # Hyperparameters written to the file header. 'ftype' 0 matches the
    # float32 data that write_tensor emits.
    hparams = {
        'n_vocab': 10,
        'n_audio_ctx': 3200,
        'n_audio_state': 8,
        'n_audio_head': 2,
        'n_audio_layer': 1,
        'n_mels': 16,
        'ftype': 0,
        'n_fft': 64,
        'subsampling_factor': 8,
        'n_subsampling_channels': 4,
        'n_conv_kernel': 3,
        'n_pred_dim': 8,
        'n_pred_layers': 1,
        'n_tdt_durations': 2,
        'n_max_tokens': 5,
    }

    # Short local aliases for the hyperparameters used in tensor shapes.
    n_vocab = hparams['n_vocab']
    n_state = hparams['n_audio_state']
    n_head = hparams['n_audio_head']
    n_layer = hparams['n_audio_layer']
    n_mels = hparams['n_mels']
    n_fft = hparams['n_fft']
    n_sub_fac = hparams['subsampling_factor']
    n_sub_ch = hparams['n_subsampling_channels']
    n_conv_ker = hparams['n_conv_kernel']
    dec_dim = hparams['n_pred_dim']
    n_pred_l = hparams['n_pred_layers']
    n_tdt = hparams['n_tdt_durations']

    # Derived sizes used for tensor shapes below.
    n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch
    n_head_dim = n_state // n_head
    n_pred_embed = n_vocab + 1
    n_lstm_gates = 4 * dec_dim
    n_joint_out = n_vocab + n_tdt + 1
    n_freqs = n_fft // 2 + 1

    def f32(*shape):
        """Return a float32 array of the given shape drawn from ``rng``."""
        return rng.standard_normal(shape).astype(np.float32)

    with open(output_path, 'wb') as fout:
        # Magic number
        fout.write(struct.pack("I", 0x67676d6c))

        # Header: one int32 per hyperparameter, in this exact order.
        for key in ['n_vocab',
                    'n_audio_ctx',
                    'n_audio_state',
                    'n_audio_head',
                    'n_audio_layer',
                    'n_mels',
                    'ftype',
                    'n_fft',
                    'subsampling_factor',
                    'n_subsampling_channels',
                    'n_conv_kernel',
                    'n_pred_dim',
                    'n_pred_layers',
                    'n_tdt_durations',
                    'n_max_tokens']:
            fout.write(struct.pack("i", hparams[key]))

        # Mel filterbank: its two dimensions, then n_mels x n_freqs floats.
        fout.write(struct.pack("i", n_mels))
        fout.write(struct.pack("i", n_freqs))
        f32(n_mels, n_freqs).tofile(fout)

        # Window: its length (n_fft here), then that many floats.
        fout.write(struct.pack("i", n_fft))
        f32(n_fft).tofile(fout)

        # TDT durations: the values 0 .. n_tdt - 1 as unsigned ints.
        for d in range(n_tdt):
            fout.write(struct.pack("I", d))

        # Vocabulary: three special tokens followed by single letters,
        # each written as a length-prefixed UTF-8 byte string.
        tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)]
        assert len(tokens) == n_vocab
        fout.write(struct.pack("i", n_vocab))
        for tok in tokens:
            tok_bytes = tok.encode('utf-8')
            fout.write(struct.pack("i", len(tok_bytes)))
            fout.write(tok_bytes)

        # Encoder pre_encode tensors
        write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc))
        write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state))

        write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3))
        write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1))

        write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3))
        write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1))

        write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1))
        write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1))

        write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3))
        write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1))

        write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1))
        write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1))

        # Per-layer encoder tensors
        for i in range(n_layer):
            p = f"encoder.layers.{i}"

            write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state))
            write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state))
            write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state))
            write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state))

            write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state))
            write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state))
            write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state))
            write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker))
            write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state))
            write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state))
            write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state))
            # np.abs keeps the running variance non-negative.
            write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state)))
            # Written as a single zero; write_tensor converts it to float32.
            num_batches = np.zeros(1, dtype=np.int32)
            write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches)
            write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state))

            write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state))
            write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state))

            write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim))
            write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim))
            write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state))
            write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state))
            write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state))
            write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state))
            write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state))

            write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state))
            write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state))
            write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state))
            write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state))

            write_tensor(fout, f"{p}.norm_out.weight", f32(n_state))
            write_tensor(fout, f"{p}.norm_out.bias", f32(n_state))

        # Prediction network
        write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim))

        def reorder_gates(data):
            """Swap the last two quarters of ``data`` along axis 0."""
            h = data.shape[0] // 4
            return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)

        for i in range(n_pred_l):
            base = f"decoder.prediction.dec_rnn.lstm"
            write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim)))
            write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim)))
            # A single bias per layer: the sum of two random vectors.
            write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates)))

        # Joint network
        write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim))
        write_tensor(fout, "joint.pred.bias", f32(dec_dim))
        write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state))
        write_tensor(fout, "joint.enc.bias", f32(dec_dim))
        write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim))
        write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out))

    # Report the size of the finished file.
    size = Path(output_path).stat().st_size
    print(f"Generated {output_path} ({size / 1024:.1f} KB)")
|
||||
|
||||
# Script entry point: the first command-line argument, when given, overrides
# the default output path.
if __name__ == '__main__':
    output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin'
    generate(output)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
torch
|
||||
numpy
|
||||
pyyaml
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
# Build the parakeet-quantize tool, then quantize the f32 Parakeet model into
# each of the types listed in the loop below.

set -e

build_dir=build
modelname=ggml-parakeet-tdt-0.6b-v3
model=models/${modelname}-f32.bin
cmd=parakeet-quantize

cmake --build "${build_dir}" --target "${cmd}" -j 12

# One quantizer run per type; `set -e` stops the script at the first failure.
for qtype in q8_0 q4_0 q4_k q2_k; do
    "${build_dir}/bin/${cmd}" "${model}" "models/${modelname}-${qtype}.bin" "${qtype}"
done
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
import argparse
|
||||
import os
|
||||
from huggingface_hub import HfApi, create_repo
|
||||
|
||||
USER_NAME = "ggml-org"
REPO_ID = f"{USER_NAME}/parakeet-GGUF"

# Shared file-name stem of every model variant.
_MODEL_STEM = "ggml-parakeet-tdt-0.6b-v3"

# (variant key, human-readable description), in the order they are listed.
_VARIANT_DESCRIPTIONS = (
    ("f32", "Full precision (F32)"),
    ("f16", "Half precision (F16)"),
    ("q8_0", "8-bit quantized (Q8_0)"),
    ("q4_0", "4-bit quantized (Q4_0)"),
    ("q4_k", "4-bit K-quantized (Q4_k)"),
)

# Maps each variant key to its local file, its name in the repository and its
# description.
MODELS = {
    variant: {
        "local_path": f"models/{_MODEL_STEM}-{variant}.bin",
        "remote_name": f"{_MODEL_STEM}-{variant}.bin",
        "description": description,
    }
    for variant, description in _VARIANT_DESCRIPTIONS
}
|
||||
|
||||
def build_model_card(uploaded_variants):
    """Return the README (model card) text for the repository.

    Args:
        uploaded_variants: Collection of ``MODELS`` keys to list under
            "Available files"; entries appear in ``MODELS`` order.

    Returns:
        The model card as one newline-joined string ending with a newline.
    """
    example = MODELS['q8_0']['remote_name']

    front_matter_and_intro = [
        "---",
        "license: mit",
        "base_model: nvidia/parakeet-tdt-0.6b-v3",
        "tags:",
        "- gguf",
        "- asr",
        "---",
        "",
        "# Parakeet TDT 0.6B v3 (GGUF)",
        "",
        "GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).",
        "",
        "## Available files",
        "",
    ]

    file_list = [
        f"- `{entry['remote_name']}` — {entry['description']}"
        for variant, entry in MODELS.items()
        if variant in uploaded_variants
    ]

    usage = [
        "",
        "## Usage",
        "",
        "Build parakeet-cli:",
        "```console",
        "git clone https://github.com/ggml-org/whisper.cpp.git",
        "cd whisper.cpp",
        "cmake -B build -S .",
        "cmake --build build --target parakeet-cli -j $(nproc)",
        "```",
        "",
        "Download a model (e.g. Q8_0):",
        "```console",
        f"hf download {REPO_ID} {example} --local-dir models",
        "```",
        "",
        "Run:",
        "```console",
        f"./build/bin/parakeet-cli -m models/{example} -f samples/jfk.wav",
        "```",
        "",
    ]

    return "\n".join(front_matter_and_intro + file_list + usage)
|
||||
|
||||
|
||||
def upload_variant(api, key):
    """Upload the local file of one model variant to the repository.

    Args:
        api: Client object providing ``upload_file``.
        key: Variant key into ``MODELS``.

    Returns:
        True when the file was uploaded, False when the local file is missing.
    """
    entry = MODELS[key]
    source = entry["local_path"]
    remote = entry["remote_name"]

    if os.path.exists(source):
        print(f" Uploading {remote} ({entry['description']})...")
        api.upload_file(
            path_or_fileobj=source,
            path_in_repo=remote,
            repo_id=REPO_ID,
            repo_type="model",
            commit_message=f"Upload {remote}",
        )
        return True

    print(f" Skipping {key}: {source} not found")
    return False
|
||||
|
||||
|
||||
def main():
    """Upload the selected model variants and refresh the model card.

    Command line: zero or more variant keys (default: every key in
    ``MODELS``) and ``--no-model-card`` to leave README.md untouched.
    Variants whose local file is missing are skipped; if nothing was
    uploaded the function returns without touching the model card.
    """
    parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face")
    parser.add_argument(
        "variants",
        nargs="*",
        default=None,
        metavar="{" + ",".join(MODELS.keys()) + "}",
        help="Model variants to upload (default: all)",
    )
    parser.add_argument(
        "--no-model-card",
        action="store_true",
        help="Skip updating the model card README",
    )
    args = parser.parse_args()

    variants = args.variants if args.variants else list(MODELS.keys())

    # Validate the requested variants before any remote side effect, so a
    # typo on the command line does not create the repository first.
    unknown = [v for v in variants if v not in MODELS]
    if unknown:
        parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})")

    api = HfApi()
    create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)

    uploaded = [key for key in variants if upload_variant(api, key)]

    if not uploaded:
        print("No models were uploaded.")
        return

    if not args.no_model_card:
        print("Updating model card...")
        # list_repo_files returns plain file-name strings; fetch the listing
        # once and match by name so variants uploaded in earlier runs stay
        # listed in the model card.
        remote_files = set(api.list_repo_files(REPO_ID, repo_type="model"))
        existing = [k for k in MODELS
                    if k in uploaded or MODELS[k]["remote_name"] in remote_files]
        # `existing` always contains every key in `uploaded`, which is
        # non-empty at this point.
        card = build_model_card(existing)
        api.upload_file(
            path_or_fileobj=card.encode(),
            path_in_repo="README.md",
            repo_id=REPO_ID,
            repo_type="model",
            commit_message="Update README.md",
        )

    print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}")
|
||||
|
||||
|
||||
# Run the uploader only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -109,23 +109,43 @@ add_library(whisper
|
|||
whisper.cpp
|
||||
)
|
||||
|
||||
# Threads::Threads is linked by both libraries below, so locate it up front.
find_package(Threads REQUIRED)

add_library(parakeet
            ../include/parakeet.h
            parakeet-arch.h
            parakeet.cpp
            )

# Usage requirements: both libraries expose the same include directories,
# language level and link dependencies.
target_include_directories(whisper  PUBLIC . ../include)
target_include_directories(parakeet PUBLIC . ../include)

target_compile_features(whisper  PUBLIC cxx_std_11) # don't bump
target_compile_features(parakeet PUBLIC cxx_std_11)

target_link_libraries(whisper  PUBLIC ggml Threads::Threads)
target_link_libraries(parakeet PUBLIC ggml Threads::Threads)

# Set the version numbers
set_target_properties(whisper parakeet PROPERTIES
    VERSION ${PROJECT_VERSION}
    SOVERSION ${SOVERSION}
)

# On big-endian targets each library gets its own *_BIG_ENDIAN define.
if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN")
    list(APPEND WHISPER_EXTRA_FLAGS  -DWHISPER_BIG_ENDIAN)
    list(APPEND PARAKEET_EXTRA_FLAGS -DPARAKEET_BIG_ENDIAN)
endif()

if (WHISPER_EXTRA_FLAGS)
    target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS})
endif()

if (PARAKEET_EXTRA_FLAGS)
    target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS})
endif()
|
||||
|
||||
|
|
@ -144,4 +164,7 @@ endif()
|
|||
# Shared builds need position-independent code and the per-library
# *_SHARED / *_BUILD definitions.
if (BUILD_SHARED_LIBS)
    set_target_properties(whisper parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON)

    target_compile_definitions(whisper  PRIVATE WHISPER_SHARED  WHISPER_BUILD)
    target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD)
endif()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,188 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
// Identifiers for every tensor of a Parakeet model. Each enumerator is used
// as a key into the lookup tables defined later in this header.
enum parakeet_tensor {
    // Encoder pre_encode
    PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_OUT_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS,

    // Encoder layers (per-layer)
    PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF1_BIAS,
    PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT,
    PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_CONV_BIAS,
    PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_BN_BIAS,
    PARAKEET_TENSOR_ENC_CONV_BN_MEAN,
    PARAKEET_TENSOR_ENC_CONV_BN_VAR,
    PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES,
    PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS,
    PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U,
    PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V,
    PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF2_BIAS,
    PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT,
    PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_OUT_BIAS,

    // Prediction network
    PARAKEET_TENSOR_PRED_EMBED_WEIGHT,
    PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH,
    PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH,
    // Single LSTM bias per layer (stored under the name "bias_h_l%d").
    PARAKEET_TENSOR_PRED_LSTM_BIAS_H,

    // Joint network
    PARAKEET_TENSOR_JOINT_PRED_WEIGHT,
    PARAKEET_TENSOR_JOINT_PRED_BIAS,
    PARAKEET_TENSOR_JOINT_ENC_WEIGHT,
    PARAKEET_TENSOR_JOINT_ENC_BIAS,
    PARAKEET_TENSOR_JOINT_NET_WEIGHT,
    PARAKEET_TENSOR_JOINT_NET_BIAS,
};
|
||||
|
||||
// Maps each tensor identifier to the name under which the tensor is stored
// in the model file. Names containing "%d" are format strings that take the
// layer index.
static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = {
    // Encoder pre_encode
    {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"},
    {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"},
    {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"},

    // Encoder layers (use %d for layer number)
    {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"},
    {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"},
    {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"},
    {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"},
    {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"},
    {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"},
    {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"},
    {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"},
    {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"},
    {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"},
    {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"},
    {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"},
    {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"},
    {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"},
    {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"},
    {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"},
    {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"},
    {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"},
    {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"},
    {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"},
    {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"},
    {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"},
    {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"},
    {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"},
    {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"},
    {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"},
    {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"},
    {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"},
    {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"},

    // Prediction network (%d is the LSTM layer number)
    {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"},
    {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"},
    {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"},
    {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"},

    // Joint network
    {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"},
    {PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"},
    {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"},
    {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"},
    {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"},
    {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"},
};
|
||||
|
||||
// Associates every parakeet_tensor id with a single ggml_op.
// NOTE(review): presumably this is the op that consumes the tensor (e.g. to
// check backend/buffer support when loading weights) -- confirm against the
// code that reads this table.
static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = {
    // Encoder pre_encode
    {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD},

    // Encoder layers
    {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL},
    {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL},
    {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL},
    {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB},
    {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV},
    // The batch counter is the only entry mapped to GGML_OP_NONE.
    {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE},
    {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL},
    {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL},
    {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL},
    {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL},
    {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD},

    // Prediction network
    {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS},
    {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD},

    // Joint network
    {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD},
    {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT},
    {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD},
};
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -118,3 +118,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE
|
|||
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
|
||||
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
|
||||
set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en")
|
||||
|
||||
# Parakeet model loading test
#
# Builds test-parakeet.cpp, links it against the `parakeet` and `common`
# targets and registers it with CTest. The model and sample locations are
# baked into the binary as the PARAKEET_MODEL_PATH / SAMPLE_PATH macros.
set(PARAKEET_TEST test-parakeet)
add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp)
target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples)
target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common)
target_compile_definitions(${PARAKEET_TEST} PRIVATE
    PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin"
    SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST})
# NOTE(review): the "gh" label presumably selects tests that run in GitHub CI -- confirm.
set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh")
|
||||
|
||||
# The following parakeet tests require a real ggml-parakeet-tdt model to have
|
||||
# been converted or downloaded:
|
||||
# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models
|
||||
#
|
||||
# They also require more audio samples than are shipped by default. These can
|
||||
# be downloaded by running:
|
||||
# $ make samples
|
||||
# add_parakeet_transcription_test(<target> <source> <sample> <expected> [<threshold>])
#
# Builds <source> into the executable <target> (linked against `parakeet` and
# `common`) and adds a `run-<target>` custom target that executes it from
# ${PROJECT_BINARY_DIR}. The test is NOT registered with add_test(), so it only
# runs when `run-<target>` is built explicitly.
#
#   TEST_TARGET                  name of the executable target
#   TEST_SOURCE                  source file to compile
#   SAMPLE_PATH                  audio sample, relative to ${PROJECT_SOURCE_DIR}
#   EXPECTED_TRANSCRIPTION_PATH  expected text, relative to ${PROJECT_SOURCE_DIR}
#   <threshold> (optional)       value compiled in as
#                                TRANSCRIPTION_SIMILARITY_THRESHOLD; defaults to 1.0
#
# Example:
#   add_parakeet_transcription_test(test-foo test-parakeet-full.cpp
#       samples/foo.wav tests/foo-expected.txt 0.95)
function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH)
    # The optional fifth positional argument overrides the default threshold.
    set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0")
    if (ARGC GREATER 4)
        set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}")
    endif()

    add_executable(${TEST_TARGET} ${TEST_SOURCE})
    target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples)
    target_link_libraries(${TEST_TARGET} PRIVATE parakeet common)
    target_compile_definitions(${TEST_TARGET} PRIVATE
        PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin"
        SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}"
        EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}"
        TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD})

    # VERBATIM keeps argument escaping identical across generators/platforms.
    add_custom_target(run-${TEST_TARGET}
        COMMAND $<TARGET_FILE:${TEST_TARGET}>
        DEPENDS ${TEST_TARGET}
        WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
        VERBATIM)
endfunction()
|
||||
|
||||
# Full-transcription checks. All three build the same source file against a
# different sample / expected-output pair.
add_parakeet_transcription_test(
    test-parakeet-full-jfk
    test-parakeet-full.cpp
    samples/jfk.wav
    tests/parakeet-expected-jfk-output.txt)

add_parakeet_transcription_test(
    test-parakeet-full-gb1
    test-parakeet-full.cpp
    samples/gb1.wav
    tests/parakeet-expected-gb1-output.txt)

# This sample passes an explicit similarity threshold (0.95) instead of the
# function's default of 1.0.
add_parakeet_transcription_test(
    test-parakeet-full-diffusion
    test-parakeet-full.cpp
    samples/diffusion2023-07-03.flac
    tests/parakeet-expected-diffusion-output.txt
    0.95)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
__pycache__
|
||||
*.tar.gz
|
||||
*.txt
|
||||
eval.conf
|
||||
venv
|
||||
LibriSpeech
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Front-end Makefile for the LibriSpeech evaluation; the real work is
# delegated to eval.mk.

# Archive holding the LibriSpeech "test-clean" audio and transcripts.
TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz

all: eval

# Run the evaluation (transcribe every sample, then compute the WER).
eval:
	$(MAKE) -f eval.mk

# Remove the transcriptions and the result file produced by `eval`.
clean:
	$(MAKE) -f eval.mk clean

# Download (resumable via -c) and unpack the audio into ./LibriSpeech.
get-audio:
	wget -c $(TAR_URL)
	tar -xf test-clean.tar.gz

# Only targets that actually have a rule are declared phony; listing
# non-existent ones would turn `make <name>` into a silent no-op.
.PHONY: all eval clean get-audio
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# parakeet.cpp/tests/librispeech
|
||||
|
||||
[LibriSpeech](https://www.openslr.org/12) is a standard dataset for
|
||||
training and evaluating automatic speech recognition systems.
|
||||
|
||||
This directory contains a set of tools to evaluate the recognition
|
||||
performance of parakeet.cpp on the LibriSpeech corpus.
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. (Prerequisite) Compile `parakeet-cli` and prepare the Parakeet
|
||||
model in `ggml` format.
|
||||
|
||||
```
|
||||
$ # Execute the commands below in the project root dir.
|
||||
$ cmake -B build
|
||||
$ cmake --build build --config Release
|
||||
```
|
||||
|
||||
2. Download the audio files from the LibriSpeech project.
|
||||
|
||||
```
|
||||
$ make get-audio
|
||||
```
|
||||
|
||||
3. Set up the environment to compute WER score.
|
||||
|
||||
```
|
||||
$ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
For example, if you use a virtual environment, you can set it up as follows:
|
||||
|
||||
```
|
||||
$ python3 -m venv venv
|
||||
$ . venv/bin/activate
|
||||
$ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
4. Run the benchmark test.
|
||||
|
||||
```
|
||||
$ make
|
||||
```
|
||||
|
||||
## How-to guides
|
||||
|
||||
### How to change the inference parameters
|
||||
|
||||
Create `eval.conf` and override variables.
|
||||
|
||||
```
|
||||
PARAKEET_MODEL = parakeet-tdt-0.6b-v3
|
||||
PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt
|
||||
```
|
||||
|
||||
Check out `eval.mk` for more details.
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# Transcribes every LibriSpeech .flac file with parakeet-cli and then runs
# eval.py to compute the word error rate.

PYTHON = python

PARAKEET_PREFIX = ../../
PARAKEET_MODEL = parakeet-tdt-0.6b-v3

PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli
PARAKEET_FLAGS = --no-prints --output-txt

# You can create eval.conf to override the PARAKEET_* variables
# defined above.
-include eval.conf

# This follows the file structure of the LibriSpeech project.
AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac))
TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS))

# We output the evaluation result to this file.
DONE = $(PARAKEET_MODEL).txt

all: $(DONE)

# The result is written to a temporary file and renamed, so an interrupted
# run never leaves a partial $(DONE) behind.
$(DONE): $(TRANS_TXTS)
	$(PYTHON) eval.py > $@.tmp
	mv $@.tmp $@

# Note: This task writes to a temporary file first to
# create the target file atomically.
%.flac.txt: %.flac
	$(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp
	mv $^.tmp.txt $^.txt

# Bundle the transcriptions (without the audio) together with the result.
archive:
	tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE)

clean:
	@rm -f $(TRANS_TXTS)
	@rm -f $(DONE)

# `archive` has no output file named "archive", so it must be phony too;
# otherwise a stray file with that name would stop the rule from running.
.PHONY: all archive clean
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import glob
|
||||
import jiwer
|
||||
from normalizers import EnglishTextNormalizer
|
||||
|
||||
def get_reference():
    """Collect the reference transcripts shipped with LibriSpeech.

    Reads every ``*.trans.txt`` file below ``LibriSpeech/`` (relative to the
    current working directory). Each non-blank line has the form
    ``<utterance-code> <text>``.

    Returns:
        dict mapping utterance code to its reference text.

    Raises:
        ValueError: if a non-blank line contains a code but no text.
    """
    ref = {}
    for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'):
        # Explicit encoding: the platform default is locale-dependent.
        with open(path, encoding="utf-8") as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    # Blank lines carry no transcript; skipping them avoids
                    # an unpacking error on the split below.
                    continue
                code, text = line.split(" ", maxsplit=1)
                ref[code] = text
    return ref
|
||||
|
||||
def get_hypothesis():
    """Collect the transcriptions produced for each audio file.

    Reads every ``*.flac.txt`` file below ``LibriSpeech/`` (relative to the
    current working directory).

    Returns:
        dict mapping utterance code (the file name without ``.flac.txt``) to
        the stripped file content.
    """
    hyp = {}
    for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'):
        # Explicit encoding: decoding with the locale default can fail or
        # corrupt non-ASCII output on some platforms.
        with open(path, encoding="utf-8") as fp:
            text = fp.read().strip()
        code = os.path.basename(path).replace('.flac.txt', '')
        hyp[code] = text
    return hyp
|
||||
|
||||
def get_codes():
    """Return the sorted utterance codes of all ``.flac`` files.

    The code is the file's base name with ``.flac`` removed; files are looked
    up below ``LibriSpeech/`` relative to the current working directory.
    """
    flac_paths = glob.glob('LibriSpeech/*/*/*/*.flac')
    return sorted(
        os.path.basename(flac_path).replace('.flac', '')
        for flac_path in flac_paths
    )
|
||||
|
||||
def main():
    """Normalize references and hypotheses, then print the corpus WER.

    Raises KeyError when an utterance code has no reference or no hypothesis.
    """
    normalize = EnglishTextNormalizer()

    references = get_reference()
    hypotheses = get_hypothesis()

    # One (reference, hypothesis) pair per utterance, in sorted code order.
    pairs = [
        (normalize(references[code]), normalize(hypotheses[code]))
        for code in get_codes()
    ]
    ref_clean = [reference for reference, _ in pairs]
    hyp_clean = [hypothesis for _, hypothesis in pairs]

    wer = jiwer.wer(ref_clean, hyp_clean)
    print(f"WER: {wer * 100:.2f}%")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
Code in this directory is adapted from OpenAI Whisper project
|
||||
(https://github.com/openai/whisper) and carries the following
|
||||
copyright and license.
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 OpenAI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .basic import BasicTextNormalizer as BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer as EnglishTextNormalizer
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import re
|
||||
import unicodedata
|
||||
|
||||
import regex
|
||||
|
||||
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||
ADDITIONAL_DIACRITICS = {
|
||||
"œ": "oe",
|
||||
"Œ": "OE",
|
||||
"ø": "o",
|
||||
"Ø": "O",
|
||||
"æ": "ae",
|
||||
"Æ": "AE",
|
||||
"ß": "ss",
|
||||
"ẞ": "SS",
|
||||
"đ": "d",
|
||||
"Đ": "D",
|
||||
"ð": "d",
|
||||
"Ð": "D",
|
||||
"þ": "th",
|
||||
"Þ": "th",
|
||||
"ł": "l",
|
||||
"Ł": "L",
|
||||
}
|
||||
|
||||
|
||||
def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Strip diacritics and turn markers, symbols and punctuation into spaces.

    The input is NFKD-normalized first. Characters listed in `keep` pass
    through untouched; entries of ADDITIONAL_DIACRITICS are replaced by their
    mapping; combining marks (category 'Mn') are dropped.
    """
    pieces = []
    for ch in unicodedata.normalize("NFKD", s):
        if ch in keep:
            pieces.append(ch)
        elif ch in ADDITIONAL_DIACRITICS:
            pieces.append(ADDITIONAL_DIACRITICS[ch])
        else:
            category = unicodedata.category(ch)
            if category == "Mn":
                # non-spacing mark: drop it entirely
                continue
            pieces.append(" " if category[0] in "MSP" else ch)
    return "".join(pieces)
|
||||
|
||||
|
||||
def remove_symbols(s: str):
    """
    Turn markers, symbols and punctuation into spaces while keeping
    diacritics (the input is NFKC-normalized, i.e. kept composed).
    """
    cleaned = []
    for ch in unicodedata.normalize("NFKC", s):
        is_symbol = unicodedata.category(ch)[0] in "MSP"
        cleaned.append(" " if is_symbol else ch)
    return "".join(cleaned)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
    """Language-agnostic normalizer: lowercases, drops bracketed text and
    symbols, and collapses whitespace."""

    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        # Pick the cleaning function once, at construction time.
        if remove_diacritics:
            self.clean = remove_symbols_and_diacritics
        else:
            self.clean = remove_symbols
        self.split_letters = split_letters

    def __call__(self, s: str):
        lowered = s.lower()
        # drop anything enclosed in <...> or [...]
        without_brackets = re.sub(r"[<\[][^>\]]*[>\]]", "", lowered)
        # drop anything enclosed in (...)
        without_parens = re.sub(r"\(([^)]+?)\)", "", without_brackets)
        cleaned = self.clean(without_parens).lower()

        if self.split_letters:
            # separate every extended grapheme cluster with a space
            cleaned = " ".join(regex.findall(r"\X", cleaned, regex.U))

        # squeeze runs of whitespace into a single space
        return re.sub(r"\s+", " ", cleaned)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,550 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th")
|
||||
for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {
|
||||
**self.multipliers_plural,
|
||||
**self.multipliers_ordinal,
|
||||
}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values())
|
||||
+ list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
    """Walk `words` and yield the normalized output tokens.

    The words are visited with a (prev, current, next) window. A pending
    number is accumulated in `value` (an int, or a str once digits are being
    concatenated) together with an optional `prefix`; both are flushed
    through `output()` whenever the number ends.

    Raises:
        ValueError: if a word from `self.words` reaches the final fallback.
    """
    prefix: Optional[str] = None
    value: Optional[Union[str, int]] = None
    # set when the current branch has already consumed `next`
    skip = False

    def to_fraction(s: str):
        # Fraction parsing keeps decimal strings exact; None if unparsable.
        try:
            return Fraction(s)
        except ValueError:
            return None

    def output(result: Union[str, int]):
        # Stringify `result`, prepend the pending prefix (if any) and reset
        # the accumulator state.
        nonlocal prefix, value
        result = str(result)
        if prefix is not None:
            result = prefix + result
        value = None
        prefix = None
        return result

    if len(words) == 0:
        return

    # Pad with None so the first/last word also get a full window.
    for prev, current, next in windowed([None] + words + [None], 3):
        if skip:
            skip = False
            continue

        next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
        has_prefix = current[0] in self.prefixes
        current_without_prefix = current[1:] if has_prefix else current
        if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
            # arabic numbers (potentially with signs and fractions)
            f = to_fraction(current_without_prefix)
            assert f is not None
            if value is not None:
                if isinstance(value, str) and value.endswith("."):
                    # concatenate decimals / ip address components
                    value = str(value) + str(current)
                    continue
                else:
                    yield output(value)

            prefix = current[0] if has_prefix else prefix
            if f.denominator == 1:
                value = f.numerator  # store integers as int
            else:
                value = current_without_prefix
        elif current not in self.words:
            # non-numeric words
            if value is not None:
                yield output(value)
            yield output(current)
        elif current in self.zeros:
            value = str(value or "") + "0"
        elif current in self.ones:
            ones = self.ones[current]

            if value is None:
                value = ones
            elif isinstance(value, str) or prev in self.ones:
                if (
                    prev in self.tens and ones < 10
                ):  # replace the last zero with the digit
                    assert value[-1] == "0"
                    value = value[:-1] + str(ones)
                else:
                    value = str(value) + str(ones)
            elif ones < 10:
                if value % 10 == 0:
                    value += ones
                else:
                    value = str(value) + str(ones)
            else:  # eleven to nineteen
                if value % 100 == 0:
                    value += ones
                else:
                    value = str(value) + str(ones)
        elif current in self.ones_suffixed:
            # ordinal or cardinal; yield the number right away
            ones, suffix = self.ones_suffixed[current]
            if value is None:
                yield output(str(ones) + suffix)
            elif isinstance(value, str) or prev in self.ones:
                if prev in self.tens and ones < 10:
                    assert value[-1] == "0"
                    yield output(value[:-1] + str(ones) + suffix)
                else:
                    yield output(str(value) + str(ones) + suffix)
            elif ones < 10:
                if value % 10 == 0:
                    yield output(str(value + ones) + suffix)
                else:
                    yield output(str(value) + str(ones) + suffix)
            else:  # eleven to nineteen
                if value % 100 == 0:
                    yield output(str(value + ones) + suffix)
                else:
                    yield output(str(value) + str(ones) + suffix)
            value = None
        elif current in self.tens:
            tens = self.tens[current]
            if value is None:
                value = tens
            elif isinstance(value, str):
                value = str(value) + str(tens)
            else:
                if value % 100 == 0:
                    value += tens
                else:
                    value = str(value) + str(tens)
        elif current in self.tens_suffixed:
            # ordinal or cardinal; yield the number right away
            tens, suffix = self.tens_suffixed[current]
            if value is None:
                yield output(str(tens) + suffix)
            elif isinstance(value, str):
                yield output(str(value) + str(tens) + suffix)
            else:
                if value % 100 == 0:
                    yield output(str(value + tens) + suffix)
                else:
                    yield output(str(value) + str(tens) + suffix)
        elif current in self.multipliers:
            multiplier = self.multipliers[current]
            if value is None:
                value = multiplier
            elif isinstance(value, str) or value == 0:
                # string values (e.g. "2.5") are scaled exactly via Fraction
                f = to_fraction(value)
                p = f * multiplier if f is not None else None
                if f is not None and p.denominator == 1:
                    value = p.numerator
                else:
                    yield output(value)
                    value = multiplier
            else:
                # scale only the part below one thousand, keep the rest
                before = value // 1000 * 1000
                residual = value % 1000
                value = before + residual * multiplier
        elif current in self.multipliers_suffixed:
            multiplier, suffix = self.multipliers_suffixed[current]
            if value is None:
                yield output(str(multiplier) + suffix)
            elif isinstance(value, str):
                f = to_fraction(value)
                p = f * multiplier if f is not None else None
                if f is not None and p.denominator == 1:
                    yield output(str(p.numerator) + suffix)
                else:
                    yield output(value)
                    yield output(str(multiplier) + suffix)
            else:  # int
                before = value // 1000 * 1000
                residual = value % 1000
                value = before + residual * multiplier
                yield output(str(value) + suffix)
            value = None
        elif current in self.preceding_prefixers:
            # apply prefix (positive, minus, etc.) if it precedes a number
            if value is not None:
                yield output(value)

            if next in self.words or next_is_numeric:
                prefix = self.preceding_prefixers[current]
            else:
                yield output(current)
        elif current in self.following_prefixers:
            # apply prefix (dollars, cents, etc.) only after a number
            if value is not None:
                prefix = self.following_prefixers[current]
                yield output(value)
            else:
                yield output(current)
        elif current in self.suffixers:
            # apply suffix symbols (percent -> '%')
            if value is not None:
                suffix = self.suffixers[current]
                if isinstance(suffix, dict):
                    # two-word suffix such as "per cent"
                    if next in suffix:
                        yield output(str(value) + suffix[next])
                        skip = True
                    else:
                        yield output(value)
                        yield output(current)
                else:
                    yield output(str(value) + suffix)
            else:
                yield output(current)
        elif current in self.specials:
            if next not in self.words and not next_is_numeric:
                # apply special handling only if the next word can be numeric
                if value is not None:
                    yield output(value)
                yield output(current)
            elif current == "and":
                # ignore "and" after hundreds, thousands, etc.
                if prev not in self.multipliers:
                    if value is not None:
                        yield output(value)
                    yield output(current)
            elif current == "double" or current == "triple":
                if next in self.ones or next in self.zeros:
                    repeats = 2 if current == "double" else 3
                    # zeros are not in self.ones, hence the default of 0
                    ones = self.ones.get(next, 0)
                    value = str(value or "") + str(ones) * repeats
                    skip = True
                else:
                    if value is not None:
                        yield output(value)
                    yield output(current)
            elif current == "point":
                if next in self.decimals or next_is_numeric:
                    value = str(value or "") + "."
            else:
                # should all have been covered at this point
                raise ValueError(f"Unexpected token: {current}")
        else:
            # all should have been covered at this point
            raise ValueError(f"Unexpected token: {current}")

    # flush a number that ran up to the end of the input
    if value is not None:
        yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
    """Tidy up the token string produced by `process_words`.

    - merges a currency amount with following cents: "$2 and ¢7" -> "$2.07"
    - rewrites sub-unit amounts as cents: "$0.36" -> "¢36"
    - spells a lone "1" / "1s" as "one" / "ones"
    """

    def combine_cents(m: Match):
        try:
            currency = m.group(1)
            integer = m.group(2)
            cents = int(m.group(3))
            return f"{currency}{integer}.{cents:02d}"
        except ValueError:
            # leave the matched text as it was (m.string would splice the
            # whole input into the result)
            return m.group(0)

    def extract_cents(m: Match):
        try:
            return f"¢{int(m.group(1))}"
        except ValueError:
            return m.group(0)

    # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
    s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
    # the decimal point is escaped so that only a literal "0." matches
    # (an unescaped "." would also rewrite e.g. "$005")
    s = re.sub(r"[€£$]0\.([0-9]{1,2})\b", extract_cents, s)

    # write "one(s)" instead of "1(s)", just for the readability
    s = re.sub(r"\b1(s?)\b", r"one\1", s)

    return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        # The mapping lives in english.json next to this module.
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        # Context manager closes the file promptly; the explicit encoding
        # avoids depending on the platform's locale default.
        with open(mapping_path, encoding="utf-8") as fp:
            self.mapping = json.load(fp)

    def __call__(self, s: str):
        """Replace every whitespace-separated word that has a mapping."""
        return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
    """Return the normalized form of ``s`` (lower-cased, contractions expanded, symbols cleaned)."""
    text = s.lower()

    # Deletions that happen before contraction handling, in this order.
    removals = (
        r"[<\[][^>\]]*[>\]]",  # anything enclosed in <...> or [...]
        r"\(([^)]+?)\)",  # anything enclosed in (...)
        self.ignore_patterns,  # filler interjections
    )
    for pattern in removals:
        text = re.sub(pattern, "", text)

    # Pull an apostrophe back onto the preceding word when whitespace precedes it.
    text = re.sub(r"\s+'", "'", text)

    # Apply the replacement table in its insertion order.
    for pattern, replacement in self.replacers.items():
        text = re.sub(pattern, replacement, text)

    # Drop commas that sit between two digits.
    text = re.sub(r"(\d),(\d)", r"\1\2", text)
    # Turn a period into a space unless a digit follows it.
    text = re.sub(r"\.([^0-9]|$)", r" \1", text)
    # Numeric symbols are kept for now so the number normalizer can see them.
    text = remove_symbols_and_diacritics(text, keep=".%$¢€£")

    text = self.standardize_numbers(text)
    text = self.standardize_spellings(text)

    # Prefix/suffix symbols that are not attached to a number become spaces.
    text = re.sub(r"[.$¢€£]([^0-9])", r" \1", text)
    text = re.sub(r"([^0-9])%", r"\1 ", text)

    # Collapse every run of whitespace into a single space.
    return re.sub(r"\s+", " ", text)
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1 @@
|
|||
My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America.
|
||||
|
|
@ -0,0 +1 @@
|
|||
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
#pragma once
|
||||
|
||||
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>
|
||||
|
||||
#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD
|
||||
#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0
|
||||
#endif
|
||||
|
||||
// Reads the whole file at `path` and returns its contents with any trailing
// '\n' / '\r' characters removed.
//
// If the file cannot be opened, a diagnostic is written to stderr and the
// process is aborted. This is an explicit check rather than an assert()
// because assert() expands to nothing when NDEBUG is defined at the point this
// header is included, which would turn a missing file into an empty expected
// transcription.
static std::string read_expected_transcription(const char * path) {
    std::ifstream fin(path);
    if (!fin.is_open()) {
        fprintf(stderr, "Failed to open expected transcription file: %s\n", path);
        std::abort();
    }

    std::string text(
        (std::istreambuf_iterator<char>(fin)),
        std::istreambuf_iterator<char>());

    // Strip trailing line terminators so a final newline in the file does not
    // take part in the comparison.
    while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) {
        text.pop_back();
    }

    return text;
}
|
||||
|
||||
// Splits `text` into lower-cased words. A word is a maximal run of bytes for
// which std::isalnum() is true; every other byte acts as a separator.
static std::vector<std::string> transcription_words(const std::string & text) {
    const auto is_word_byte = [](unsigned char c) { return std::isalnum(c) != 0; };
    const auto to_lower     = [](unsigned char c) { return (char) std::tolower(c); };

    std::vector<std::string> tokens;

    std::string::const_iterator       pos = text.begin();
    const std::string::const_iterator end = text.end();

    for (;;) {
        // Skip separators up to the start of the next word.
        pos = std::find_if(pos, end, is_word_byte);
        if (pos == end) {
            break;
        }

        // Find where the word ends and store a lower-cased copy of it.
        const std::string::const_iterator stop = std::find_if_not(pos, end, is_word_byte);

        std::string token(pos, stop);
        std::transform(token.begin(), token.end(), token.begin(), to_lower);
        tokens.push_back(token);

        pos = stop;
    }

    return tokens;
}
|
||||
|
||||
static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) {
|
||||
const std::vector<std::string> expected_words = transcription_words(expected);
|
||||
const std::vector<std::string> actual_words = transcription_words(actual);
|
||||
|
||||
if (expected_words.empty() && actual_words.empty()) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
if (expected_words.empty() || actual_words.empty()) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
std::vector<int> prev(actual_words.size() + 1, 0);
|
||||
std::vector<int> cur (actual_words.size() + 1, 0);
|
||||
|
||||
for (size_t i = 0; i < expected_words.size(); ++i) {
|
||||
std::fill(cur.begin(), cur.end(), 0);
|
||||
|
||||
for (size_t j = 0; j < actual_words.size(); ++j) {
|
||||
if (expected_words[i] == actual_words[j]) {
|
||||
cur[j + 1] = prev[j] + 1;
|
||||
} else {
|
||||
cur[j + 1] = std::max(prev[j + 1], cur[j]);
|
||||
}
|
||||
}
|
||||
|
||||
prev.swap(cur);
|
||||
}
|
||||
|
||||
const int lcs = prev[actual_words.size()];
|
||||
return (2.0 * lcs) / (expected_words.size() + actual_words.size());
|
||||
}
|
||||
|
||||
static bool verify_transcription(const std::string & expected, const std::string & actual) {
|
||||
const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD;
|
||||
|
||||
if (threshold >= 1.0) {
|
||||
if (actual == expected) {
|
||||
return true;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
fprintf(stderr, "[Failed] Transcript mismatched\n");
|
||||
fprintf(stderr, "expected:\n%s\n\n", expected.c_str());
|
||||
fprintf(stderr, "actual:\n%s\n", actual.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
const double similarity = transcription_lcs_similarity(expected, actual);
|
||||
printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold);
|
||||
|
||||
if (similarity >= threshold) {
|
||||
return true;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold);
|
||||
fprintf(stderr, "Expected:\n%s\n\n", expected.c_str());
|
||||
fprintf(stderr, "Actual:\n%s\n", actual.c_str());
|
||||
return false;
|
||||
}
|
||||
|
|
@ -21,13 +21,21 @@ cd `dirname $0`
|
|||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" )
|
||||
|
||||
# Parakeet model variants
|
||||
parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
printf "\n"
|
||||
printf " Available models:"
|
||||
printf " Available whisper models:"
|
||||
for model in "${models[@]}"; do
|
||||
printf " $model"
|
||||
done
|
||||
printf "\n"
|
||||
printf " Available parakeet models:"
|
||||
for model in "${parakeet_models[@]}"; do
|
||||
printf " parakeet-$model"
|
||||
done
|
||||
printf "\n\n"
|
||||
}
|
||||
|
||||
|
|
@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then
|
|||
fi
|
||||
|
||||
model=$1
|
||||
main="../build/bin/whisper-cli"
|
||||
|
||||
threads=""
|
||||
if [ $# -eq 2 ]; then
|
||||
threads="-t $2"
|
||||
fi
|
||||
|
||||
if [ ! -f ../models/ggml-$model.bin ]; then
|
||||
printf "Model $model not found. Aborting\n"
|
||||
# Decide whether a Parakeet model was requested. Two spellings are accepted:
# "parakeet-<variant>" or a bare variant name from parakeet_models (e.g. "f32").
is_parakeet=0
parakeet_variant=""

case "$model" in
    parakeet-*)
        is_parakeet=1
        parakeet_variant="${model#parakeet-}"
        ;;
esac

for v in "${parakeet_models[@]}"; do
    if [ "$model" = "$v" ]; then
        is_parakeet=1
        parakeet_variant="$v"
        break
    fi
done

# Pick the CLI binary and the model file for the detected model family.
if [ "$is_parakeet" -eq 1 ]; then
    main="../build/bin/parakeet-cli"
    model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin"
else
    main="../build/bin/whisper-cli"
    model_path="../models/ggml-${model}.bin"
fi
|
||||
|
||||
# Stop early when the resolved model file does not exist. The path is quoted so
# a value containing spaces or glob characters is tested as a single word, and
# the values are passed as printf arguments so a '%' or '\' in a model name is
# printed literally instead of being interpreted as part of the format.
if [ ! -f "$model_path" ]; then
    printf "Model %s not found (%s). Aborting\n" "$model" "$model_path"
    list_models
    exit 1
fi
|
||||
|
|
@ -110,7 +140,11 @@ function run_lang() {
|
|||
fi
|
||||
fi
|
||||
|
||||
$main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null
|
||||
if [ $is_parakeet -eq 1 ]; then
|
||||
$main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null
|
||||
else
|
||||
$main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null
|
||||
fi
|
||||
|
||||
git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt
|
||||
|
||||
|
|
@ -120,7 +154,7 @@ function run_lang() {
|
|||
|
||||
run_lang "en" "${urls_en[@]}"
|
||||
|
||||
if [[ $model != *.en* ]]; then
|
||||
if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then
|
||||
run_lang "es" "${urls_es[@]}"
|
||||
run_lang "it" "${urls_it[@]}"
|
||||
run_lang "pt" "${urls_pt[@]}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
#include "parakeet.h"
|
||||
#include "common-whisper.h"
|
||||
#include "parakeet-verification.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
#include <cassert>
|
||||
|
||||
// State shared with token_callback through its user_data pointer.
struct test_state {
    // Starts out true; token_callback clears it after handling a token.
    bool is_first{true};

    // Text of every token seen so far, concatenated in callback order.
    std::string transcript{};
};
|
||||
|
||||
// Progress callback used by the test: it only records that it was invoked by
// setting the bool that `user_data` points to.
void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) {
    (void) ctx;
    (void) state;
    (void) progress;

    bool & was_called = *static_cast<bool *>(user_data);
    was_called = true;
}
|
||||
|
||||
// Encoder-begin callback used by the test: sets the bool that `user_data`
// points to and returns true.
bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) {
    (void) ctx;
    (void) state;

    bool & was_called = *static_cast<bool *>(user_data);
    was_called = true;

    return true;
}
|
||||
|
||||
// Abort callback used by the test: sets the bool that `user_data` points to
// and returns false, i.e. it never requests an abort.
bool abort_callback(void * user_data) {
    bool & was_called = *static_cast<bool *>(user_data);
    was_called = true;

    return false; // keep going, do not abort
}
|
||||
|
||||
// New-token callback: converts the token to text, echoes it to stdout and
// appends it to the transcript held in the test_state behind `user_data`.
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
    (void) state;

    test_state & ts = *static_cast<test_state *>(user_data);

    // Fixed-size buffer for the text form of a single token.
    char piece[256];
    const char * raw_token = parakeet_token_to_str(ctx, token_data->id);
    parakeet_token_to_text(raw_token, ts.is_first, piece, sizeof(piece));

    // Flush right away so the output appears as tokens arrive.
    printf("%s", piece);
    fflush(stdout);

    ts.transcript.append(piece);
    ts.is_first = false;
}
|
||||
|
||||
int main() {
|
||||
std::string model_path = PARAKEET_MODEL_PATH;
|
||||
std::string sample_path = SAMPLE_PATH;
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
|
||||
assert(pcmf32.size() > 0);
|
||||
assert(pcmf32s.size() == 0); // no stereo vector
|
||||
|
||||
printf("Loading Parakeet model from: %s\n", model_path.c_str());
|
||||
|
||||
struct parakeet_context_params ctx_params = parakeet_context_default_params();
|
||||
|
||||
struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params);
|
||||
if (pctx == nullptr) {
|
||||
fprintf(stderr, "Failed to load Parakeet model\n");
|
||||
return 1;
|
||||
}
|
||||
printf("Successfully loaded Parakeet model\n");
|
||||
|
||||
struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
test_state tstate;
|
||||
params.new_token_callback = token_callback;
|
||||
params.new_token_callback_user_data = &tstate;
|
||||
bool progress_callback_called = false;
|
||||
params.progress_callback = progress_callback;
|
||||
params.progress_callback_user_data = &progress_callback_called;
|
||||
bool encoder_begin_callback_called = false;
|
||||
params.encoder_begin_callback = encoder_begin_callback;
|
||||
params.encoder_begin_callback_user_data = &encoder_begin_callback_called;
|
||||
bool abort_callback_called = false;
|
||||
params.abort_callback = abort_callback;
|
||||
params.abort_callback_user_data = &abort_callback_called;
|
||||
|
||||
int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size());
|
||||
assert(ret == 0);
|
||||
assert(progress_callback_called);
|
||||
assert(encoder_begin_callback_called);
|
||||
assert(abort_callback_called);
|
||||
|
||||
const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH);
|
||||
const bool transcript_matches = verify_transcription(expected, tstate.transcript);
|
||||
|
||||
parakeet_free(pctx);
|
||||
|
||||
if (!transcript_matches) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("\nTest passed: parakeet_full succeeded!\n");
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
#include "parakeet.h"
|
||||
#include "common-whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
#include <cassert>
|
||||
|
||||
// New-token callback: converts the token to text and echoes it to stdout.
// A function-local static tracks whether this is the first token seen in the
// process; its value is forwarded to parakeet_token_to_text.
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
    (void) state;
    (void) user_data;

    static bool is_first = true;

    const char * token_str = parakeet_token_to_str(ctx, token_data->id);

    // Fixed-size buffer for the text form of a single token.
    char text_buf[256];
    parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf));

    // Flush right away so the output appears as tokens arrive.
    printf("%s", text_buf);
    fflush(stdout);

    is_first = false;
}
|
||||
|
||||
// New-segment callback: prints each of the `n_new` most recent segments in
// `state` (index, t0/t1 and text), followed by one line per token with the
// fields of its parakeet_token_data.
void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) {
    (void) user_data;

    const int total_segments = parakeet_full_n_segments_from_state(state);
    const int first_new      = total_segments - n_new; // index of the first new segment

    printf("\nSegment Callback: %d new segment(s)\n", n_new);

    for (int seg = first_new; seg < total_segments; ++seg) {
        const char *  seg_text = parakeet_full_get_segment_text_from_state(state, seg);
        const int64_t seg_t0   = parakeet_full_get_segment_t0_from_state(state, seg);
        const int64_t seg_t1   = parakeet_full_get_segment_t1_from_state(state, seg);

        printf("Segment %d: [%lld -> %lld] \"%s\"\n", seg, (long long)seg_t0, (long long)seg_t1, seg_text);
        printf("Tokens:\n");

        const int token_count = parakeet_full_n_tokens_from_state(state, seg);
        for (int tok = 0; tok < token_count; ++tok) {
            const parakeet_token_data td     = parakeet_full_get_token_data_from_state(state, seg, tok);
            const char *              td_str = parakeet_token_to_str(ctx, td.id);

            printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n",
                   tok,
                   td.id,
                   td.frame_index,
                   td.duration_idx,
                   td.duration_value,
                   td.p,
                   td.plog,
                   (long long)td.t0,
                   (long long)td.t1,
                   td.is_word_start,
                   td_str);
        }
    }

    printf("\n");
}
|
||||
|
||||
int main() {
|
||||
std::string model_path = PARAKEET_MODEL_PATH;
|
||||
std::string sample_path = SAMPLE_PATH;
|
||||
|
||||
// Load the sample audio file
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
|
||||
assert(pcmf32.size() > 0);
|
||||
assert(pcmf32s.size() == 0);
|
||||
|
||||
printf("Loading Parakeet model from: %s\n", model_path.c_str());
|
||||
|
||||
struct parakeet_context_params ctx_params = parakeet_context_default_params();
|
||||
|
||||
struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params);
|
||||
if (pctx == nullptr) {
|
||||
fprintf(stderr, "Failed to load Parakeet model\n");
|
||||
return 1;
|
||||
}
|
||||
printf("Successfully loaded Parakeet model\n");
|
||||
|
||||
struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
params.new_token_callback = token_callback;
|
||||
params.new_token_callback_user_data = nullptr;
|
||||
params.new_segment_callback = segment_callback;
|
||||
params.new_segment_callback_user_data = nullptr;
|
||||
parakeet_state * state = parakeet_init_state(pctx);
|
||||
|
||||
int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size());
|
||||
assert(ret == 0);
|
||||
|
||||
parakeet_free_state(state);
|
||||
parakeet_free(pctx);
|
||||
|
||||
printf("\nTest passed: Parakeet model loaded and freed successfully\n");
|
||||
return 0;
|
||||
}
|
||||
Loading…
Reference in New Issue