Compare commits

6 Commits

Author SHA1 Message Date
Daniel Bevenius f049fff95a
release : v1.9.1 (#3892) 2026-06-19 06:12:37 +02:00
Daniel Bevenius 200b119790
ci : add GGML_NATIVE=OFF and GGML_BMI2=OFF to windows-blas (#3891)
* ci : add GGML_NATIVE=OFF and build all cpu-variants

This commit adds -DGGML_BACKEND_DL=ON, -DGGML_NATIVE=OFF, and
-DGGML_CPU_ALL_VARIANTS=ON to the releases.

The motivation for this is that currently the Windows BLAS build
uses the native CPU instructions and if target systems do not support
these instructions, the build will fail like the linked issue reports.

Resolves: https://github.com/ggml-org/whisper.cpp/issues/3889

* ci : update ubuntu-cpu release job for all variants [no ci]

This commit enables the ubuntu-cpu job to include all cpu variants and
ensures that the ggml backend libraries are built into the bin directory
similar to how llama.cpp does it.

The following is a build on my fork with this change:
https://github.com/danbev/whisper.cpp/releases/tag/untagged-fc3c71f0bf0f7bf19d19
2026-06-18 14:49:08 +02:00
Daniel Bevenius 86c40c3bd6
release : v1.9.0 (#3886) 2026-06-17 11:36:57 +02:00
KITAITI Makoto 0d14756929
ruby : add support for Parakeet (#3885)
* Add Whisper::Parakeet::Params

* Add tests for Parakeet::Params

* Remove unused variable

* Add callbacks to Parakeet::Params

* Group callback and user_data params

* Undefine local macros

* Define GetParakeetParams

* Remove unused variable

* Use ITERATE_CALLBACK_PARAMS

* Use ITERATE_CALLBACK_PARAMS instead of ITERATE_USER_DATA_PARAMS

* Fix memsize

* Remove unnecessary macros

* Simplify params registration

* Define Parakeet

* Add hook methods to Parakeet::Params

* Fix typo

* Check callback container in GetParakeetParams

* Reduce if

* Free parakeet_full_params

* Implement Parakeet::Context#initialize

* Add TestParakeetContext

* Add Parakeet::Segment

* Prevent double-free

* Add Parakeet::Context#transcribe

* Add Parakeet::Context#each_segment

* Define Parakeet::Segment attributes

* Define Parakeet::Segment#deconstruct_keys

* Add tests for Parakeet::Segment#deconstruct_keys

* Run Parakeet::Context#transcribe without GVL

* Make it to abort for Parakeet

* Add Parakeet.log_set

* Define Parakeet::Token

* Define Parakeet::Segment#each_token

* Implement some hooks of Parakeet::Params

* Convert int to VALUE

* Implement hooks for Parakeet

* Implement Parakeet::Context#full

* Add tests for Parakeet::Context#full

* Add Parakeet to RBS

* Fix ruby_whisper_parakeet_params_free

* Free ruby_whisper_parakeet_context

* Add tests for hooks

* Add Parakeet section to README

* Add more attributes of Parakeet::Context

* Add tests for Parakeet::Context's attributes

* Update RBS

* Register parakeet-tdt-0.6b-v3

* Narrow scope of log constants

* Extract activate and deactivate of log_queue

* Make start_log_callback_thread private

* Don't call start_log_callback_thread unnecessarily

* Early return from log_queue_enqueue when not active

* Group log_queue members

* is_active -> is_open

* Fix English

* Share parakeet full body function

* ruby_whisper_parakeet_abort_callback_user_data -> ruby_whisper_abort_callback_user_data

* NULL check for callback containers

* Fix Parakeet.log_set

* Omit Parakeet tests on CI

* Extract Whisper::LogSettable

* Join log callback thread in a log queue function

* Revert Join log callback thread in a log queue function

* Extract output methods to modules

* Move Parakeet init functions into init_parakeet()

* Add output methods to Parakeet classes

* Add Parakeet's output methods to RBS

* Use Whisper::Output in RBS

* Add LogSettable to RBS

* Fix module Token -> class Token

* Add Parakeet::Model

* Add test for Parakeet::Model

* Add Parakeet::Model to RBS

* Move position of Parakeet::Model in RBS

* Parakeet -> TestBase::Parakeet

* Add Parakeet::Context#model in RBS

* Add Whisper::Output

* Fix nil check

* Define ruby_whisper_parakeet_model_memsize

* Fix order of declaration in ruby_whisper_parakeet_model_get_xxx

* Define Parakeet.system_info_str

* Add test for Parakeet.system_info_str

* Add signature of Parakeet.system_info_str

* Define Parakeet::VERSION

* Add test for Parakeet::VERSION

* Add signature of Parakeet::VERSION

* Add Parakeet::Context::Params

* Make Parakeet::Context.new accept Context::Params

* Add test for Parakeet::Context.new with Context::Params

* Update RBS

* Remove params from Parakeet::Params which are moved from whisper_parakeet_full_params

* Remove tests for removed params

* Make Parakeet tests follow original behavior changes

* Add Parakeet model shortcuts

* Alloc token data in factory instead of alloc func

* Fix variable name

* Update RBS

* Refactor log settable module

* Use log settable for Whisper

* Address deadlock

* Make test follow change of log queue implementation

* Refactor abort callback to work the same way as Parakeet's

* Remove redundant structs

* Fix test name

* Fix README

* Add missing parallel transcription

* Fix test for parakeet info

* Remove removed params

* Wait for logs dequeued

* Fix instance variable name

* Load etc feature

* Remove unnecessary comment

* Remove unnecessary thread safety check

* Remove outdated comment

* Skip downloading model if cache exists

* Change Hugging Face URI for Parakeet models

* Bump required Ruby version to 3.3

* Fix English
2026-06-17 06:42:09 +02:00
Daniel Bevenius 9efddafb91
parakeet : add support for NVIDIA Parakeet (#3735)
* parakeet : add support for NVIDIA Parakeet


Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-16 20:44:10 +02:00
Daniel Bevenius 3805e602d3
ci : only trigger release jobs for tags (#3883)
* ci : only trigger release jobs for tags

This commit removes the building of the release jobs on pushes to
master.

The motivation for this is that it can be confusing at the moment when
releasing that the push to master also triggers the release jobs but
the actual release will be skipped. With this change the release job is
only run when a tag is pushed, which should result in a single Release
GitHub Actions job and make it easier to follow.

* ci : add GGML_NATIVE=OFF for ubuntu-22-gcc
2026-06-16 14:33:42 +02:00
80 changed files with 11753 additions and 348 deletions

View File

@ -27,6 +27,6 @@ jobs:
steps:
- uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0
with:
ruby-version: '3.2'
ruby-version: '3.3'
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6
- run: rake test

View File

@ -75,6 +75,7 @@ jobs:
apt update
apt install -y build-essential cmake libsdl2-dev git ccache
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \
-DGGML_NATIVE=OFF \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache
make

View File

@ -13,8 +13,6 @@ on:
type: string
push:
branches:
- master
tags:
- 'v*'
@ -117,9 +115,11 @@ jobs:
run: |
cmake -B build \
-DCMAKE_BUILD_TYPE=Release \
-DBUILD_SHARED_LIBS=OFF \
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DGGML_BACKEND_DL=ON \
-DGGML_NATIVE=OFF \
${{ matrix.build == 'arm64' && '-DGGML_CPU_ARM_ARCH=armv8-a' || '' }}
${{ matrix.build == 'x64' && '-DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_CPU_ARM_ARCH=armv8-a' }}
cmake --build build --config Release -j $(nproc)
- name: Pack artifacts
@ -175,7 +175,7 @@ jobs:
-DBUILD_SHARED_LIBS=ON
-DWHISPER_SDL2=${{ matrix.sdl2 }}
-DGGML_NATIVE=OFF
-DGGML_BMI2=OFF
${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }}
- name: Build
run: |
@ -289,6 +289,8 @@ jobs:
-DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib"
-DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include"
-DWHISPER_SDL2=${{ matrix.sdl2 }}
-DGGML_NATIVE=OFF
${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }}
- name: Build
run: |
@ -492,7 +494,10 @@ jobs:
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%" ^
-DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^
-DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%"
-DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" ^
-DGGML_BACKEND_DL=ON ^
-DGGML_NATIVE=OFF ^
-DGGML_CPU_ALL_VARIANTS=ON
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
project("whisper.cpp" C CXX)
project("whisper.cpp" VERSION 1.8.7)
project("whisper.cpp" VERSION 1.9.1)
include(CheckIncludeFileCXX)
set(SOVERSION 1)
@ -19,6 +19,7 @@ endif()
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(WHISPER_STANDALONE ON)
@ -180,12 +181,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location
get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h)
install(TARGETS whisper LIBRARY PUBLIC_HEADER)
target_compile_definitions(whisper PRIVATE
WHISPER_VERSION="${PROJECT_VERSION}"
)
set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h)
install(TARGETS parakeet LIBRARY PUBLIC_HEADER)
target_compile_definitions(parakeet PRIVATE
PARAKEET_VERSION="${PROJECT_VERSION}"
)
configure_package_config_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake
@ -211,6 +220,35 @@ configure_file(cmake/whisper.pc.in
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc"
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
configure_package_config_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet
PATH_VARS
PARAKEET_INCLUDE_INSTALL_DIR
PARAKEET_LIB_INSTALL_DIR
PARAKEET_BIN_INSTALL_DIR)
write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake
VERSION ${WHISPER_INSTALL_VERSION}
COMPATIBILITY SameMajorVersion)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake
${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet)
configure_file(cmake/parakeet.pc.in
"${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc"
@ONLY)
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc"
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
#
# programs, examples and tests
#

View File

@ -7,7 +7,7 @@
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.8.7](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.7) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
Stable: [v1.9.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.8.7",
"version": "1.9.1",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

View File

@ -396,6 +396,37 @@ whisper
.full(Whisper::Params.new, samples)
```
### Parakeet ###
whispercpp gem now supports NVIDIA's ASR model Parakeet.
If you want to use Parakeet instead of Whisper, the API should feel familiar.
In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way.
```ruby
require "whisper"
# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem.
Parakeet = Whisper::Parakeet
parakeet = Parakeet::Context.new("path/to/model")
params = Parakeet::Params.new(
no_context: true
)
parakeet
.transcribe("path/to/audio.wav", params)
.each_segment do |segment|
puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}"
end
```
The main differences are:
* Namespace is `Whisper::Parakeet`.
* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks.
Custom context params
---------------------

View File

@ -84,6 +84,21 @@ else
end
end
TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin"
TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin"))
TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d")
directory TEST_PARAKEET_MODEL_DIR
if File.exist? TEST_PARAKEET_MODEL_SRC
file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t|
symlink t.source, t.name
end
else
require "open-uri"
file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t|
File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read
end
end
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
chdir "test/jfk_reader" do
@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
end
CLEAN.include TEST_MEMORY_VIEW
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL]

View File

@ -30,6 +30,6 @@ create_makefile "whisper" do |conf|
#{libs}: cmake-targets
cmake-targets:
#{"\t"}"#{cmake}" -S sources -B build #{options}
#{"\t"}"#{cmake}" --build build --config Release --target common whisper
#{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet
EOF
end

View File

@ -1,19 +1,29 @@
#include "ruby_whisper.h"
VALUE mWhisper;
VALUE mLogSettable;
VALUE mVAD;
VALUE mParakeet;
VALUE cContext;
VALUE cParams;
VALUE cVADContext;
VALUE cVADParams;
VALUE cVADSegments;
VALUE cVADSegment;
VALUE cParakeetContext;
VALUE cParakeetContextParams;
VALUE cParakeetParams;
VALUE cParakeetSegment;
VALUE cParakeetModel;
VALUE eError;
VALUE cSegment;
VALUE cToken;
VALUE cModel;
VALUE mOutputContext;
VALUE mOutputSegment;
ID id_to_s;
ID id_call;
ID id___method__;
@ -27,9 +37,11 @@ ID id_pre_converted_models;
ID id_coreml_compiled_models;
ID id_cache;
ID id_n_processors;
static bool is_log_callback_finalized = false;
static bool is_ruby_log_callback_present = false;
ID id_extended;
ID id_start_log_callback_thread;
ID id_log_callback_thread;
ID id_alive_p;
ID id_join;
// High level API
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD);
extern void init_ruby_whisper_vad_context(VALUE *mVAD);
extern void init_ruby_whisper_vad_segment(VALUE *mVAD);
extern void init_ruby_whisper_vad_segments(VALUE *mVAD);
extern void init_ruby_whisper_parakeet(VALUE *mWhisper);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
static ruby_whisper_log_queue whisper_log_queue;
LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set)
/*
* call-seq:
* lang_max_id -> Integer
@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) {
return rb_str_new2(whisper_print_system_info());
}
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
is_log_callback_finalized = true;
return Qnil;
}
typedef struct {
int level;
const char * buffer;
} call_log_callbacks_args;
static void*
call_log_callbacks(void *v_args) {
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
if (NIL_P(log_callback)) {
return NULL;
}
call_log_callbacks_args *args = (call_log_callbacks_args *)v_args;
VALUE user_data = rb_iv_get(mWhisper, "user_data");
rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data);
return NULL;
}
static void
ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
if (is_log_callback_finalized) {
return;
}
if (!is_ruby_log_callback_present) {
return;
}
call_log_callbacks_args args = {
level,
buffer,
};
if (ruby_thread_has_gvl_p()) {
call_log_callbacks((void *)&args);
} else {
rb_thread_call_with_gvl(call_log_callbacks, (void *)&args);
}
}
/*
* call-seq:
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
*/
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
VALUE old_callback = rb_iv_get(self, "log_callback");
if (!NIL_P(old_callback)) {
rb_undefine_finalizer(old_callback);
}
rb_iv_set(self, "log_callback", log_callback);
rb_iv_set(self, "user_data", user_data);
if (!NIL_P(log_callback)) {
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
rb_define_finalizer(log_callback, finalize_log_callback);
}
if (NIL_P(log_callback)) {
whisper_log_set(NULL, NULL);
is_ruby_log_callback_present = false;
} else {
whisper_log_set(ruby_whisper_log_callback, NULL);
is_ruby_log_callback_present = true;
}
return Qnil;
}
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
@ -189,9 +133,19 @@ void Init_whisper() {
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
id_cache = rb_intern("cache");
id_n_processors = rb_intern("n_processors");
id_extended = rb_intern("extended");
id_start_log_callback_thread = rb_intern("start_log_callback_thread");
id_log_callback_thread = rb_intern("@log_callback_thread");
id_alive_p = rb_intern("alive?");
id_join = rb_intern("join");
mWhisper = rb_define_module("Whisper");
rb_require("whisper/log_settable");
mLogSettable = rb_path2class("Whisper::LogSettable");
mVAD = rb_define_module_under(mWhisper, "VAD");
rb_require("whisper/output");
mOutputContext = rb_path2class("Whisper::Output::Context");
mOutputSegment = rb_path2class("Whisper::Output::Segment");
rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
@ -222,8 +176,8 @@ void Init_whisper() {
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
LOG_SETTABLE_INIT(whisper_log_queue, mWhisper)
cContext = init_ruby_whisper_context(&mWhisper);
init_ruby_whisper_context_params(&cContext);
@ -236,8 +190,10 @@ void Init_whisper() {
init_ruby_whisper_vad_segment(&mVAD);
init_ruby_whisper_vad_segments(&mVAD);
init_ruby_whisper_vad_context(&mVAD);
init_ruby_whisper_parakeet(&mWhisper);
rb_require("whisper/context");
rb_require("whisper/segment");
rb_require("whisper/model/uri");
rb_include_module(cContext, mOutputContext);
rb_include_module(cSegment, mOutputSegment);
}

View File

@ -5,8 +5,12 @@
#include <ruby/version.h>
#include <ruby/util.h>
#include <ruby/thread.h>
#include <ruby/thread_native.h>
#include <ruby/atomic.h>
#include <ruby/memory_view.h>
#include "whisper.h"
#include "parakeet.h"
#include "ruby_whisper_log_settable.h"
#if RUBY_API_VERSION_MAJOR < 4
// Exists but not declared as public API
@ -20,13 +24,28 @@ typedef struct {
VALUE callbacks;
} ruby_whisper_callback_container;
typedef struct {
VALUE *context;
VALUE user_data;
VALUE callback;
VALUE callbacks;
bool is_interrupted;
} ruby_whisper_abort_callback_container;
typedef struct ruby_whisper_abort_callback_user_data {
volatile rb_atomic_t is_interrupted;
ruby_whisper_callback_container *callback_container;
} ruby_whisper_abort_callback_user_data;
typedef struct ruby_whisper_log {
enum ggml_log_level level;
char *text;
size_t length;
size_t capacity;
} ruby_whisper_log;
typedef struct ruby_whisper_log_queue {
rb_nativethread_lock_t lock;
rb_nativethread_cond_t cond;
bool is_open;
size_t head;
size_t tail;
size_t size;
ruby_whisper_log *logs;
} ruby_whisper_log_queue;
typedef struct {
struct whisper_context *context;
@ -42,7 +61,7 @@ typedef struct {
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_abort_callback_container *abort_callback_container;
ruby_whisper_callback_container *abort_callback_container;
VALUE vad_params;
} ruby_whisper_params;
@ -84,6 +103,63 @@ typedef struct parsed_samples_t {
bool memview_exported;
} parsed_samples_t;
typedef struct {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
} ruby_whisper_full_args;
typedef struct ruby_whisper_full_parallel_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
int n_processors;
} ruby_whisper_full_parallel_args;
typedef struct {
struct parakeet_full_params params;
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *new_token_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_parakeet_params;
typedef struct {
struct parakeet_context_params params;
} ruby_whisper_parakeet_context_params;
typedef struct {
struct parakeet_context *context;
} ruby_whisper_parakeet_context;
typedef struct {
VALUE context;
int index;
} ruby_whisper_parakeet_segment;
typedef struct {
parakeet_token_data *token_data;
VALUE text;
} ruby_whisper_parakeet_token;
typedef struct {
VALUE context;
} ruby_whisper_parakeet_model;
extern ID id_extended;
extern ID id_log_callback_thread;
extern ID id_start_log_callback_thread;
extern ID id_alive_p;
extern ID id_join;
extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue);
extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue);
extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue);
extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text);
extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue);
#define GetContext(obj, rw) do { \
TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \
if ((rw)->context == NULL) { \
@ -120,4 +196,47 @@ typedef struct parsed_samples_t {
} \
} while (0)
#define GetParakeetContextParams(obj, rwpcp) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \
} while (0)
#define GetParakeetContext(obj, rwpc) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \
if ((rwpc)->context == NULL) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetParams(obj, rwpp) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \
if (!(rwpp)->new_segment_callback_container || \
!(rwpp)->new_token_callback_container || \
!(rwpp)->progress_callback_container || \
!(rwpp)->encoder_begin_callback_container || \
!(rwpp)->abort_callback_container) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetSegment(obj, rwps) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \
if (!(rwps)->context) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetToken(obj, rwpt) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \
if (!(rwpt)->token_data) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#define GetParakeetModel(obj, rwpm) do { \
TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \
if (NIL_P((rwpm)->context)) { \
rb_raise(rb_eRuntimeError, "Not initialized"); \
} \
} while (0)
#endif

View File

@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
ID transcribe_option_names[1];
@ -38,21 +38,6 @@ typedef struct fill_samples_args {
int n_samples;
} fill_samples_args;
typedef struct full_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
} full_args;
typedef struct full_parallel_args {
VALUE *context;
VALUE *params;
float *samples;
int n_samples;
int n_processors;
} full_parallel_args;
typedef struct full_without_gvl_args {
struct whisper_context *context;
struct whisper_full_params *params;
@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args {
} full_parallel_without_gvl_args;
typedef struct full_ubf_args {
ruby_whisper_abort_callback_container *abort_callback_container;
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
} full_ubf_args;
static void
@ -379,7 +364,7 @@ fill_samples(VALUE rb_args)
return Qnil;
}
struct parsed_samples_t
parsed_samples_t
parse_samples(VALUE *samples, VALUE *n_samples)
{
bool memview_available = rb_memory_view_available_p(*samples);
@ -480,20 +465,24 @@ full_ubf(void *rb_args)
{
full_ubf_args *args = (full_ubf_args *)rb_args;
args->abort_callback_container->is_interrupted = true;
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
}
static VALUE
VALUE
full_body(VALUE rb_args)
{
full_args *args = (full_args *)rb_args;
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
ruby_whisper *rw;
ruby_whisper_params *rwp;
GetContext(*args->context, rw);
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
prepare_transcription(rwp, args->context, 1);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
prepare_transcription(rwp, args->context, 1, &abort_callback_user_data);
struct full_without_gvl_args full_without_gvl_args = {
rw->context,
@ -503,7 +492,7 @@ full_body(VALUE rb_args)
0,
};
full_ubf_args full_ubf_args = {
rwp->abort_callback_container,
&abort_callback_user_data,
};
rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_without_gvl_args.result);
@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
VALUE n_samples = argc == 2 ? Qnil : argv[2];
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
full_args args = {
ruby_whisper_full_args args = {
&self,
&argv[0],
parsed.samples,
@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args)
return NULL;
}
static VALUE
VALUE
full_parallel_body(VALUE rb_args)
{
full_parallel_args *args = (full_parallel_args *)rb_args;
ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args;
ruby_whisper *rw;
ruby_whisper_params *rwp;
GetContext(*args->context, rw);
TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
prepare_transcription(rwp, args->context, args->n_processors);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data);
struct full_parallel_without_gvl_args full_parallel_without_gvl_args = {
rw->context,
@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args)
0,
};
full_ubf_args full_ubf_args = {
rwp->abort_callback_container,
&abort_callback_user_data,
};
rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_parallel_without_gvl_args.result);
@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
break;
}
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
const full_parallel_args args = {
const ruby_whisper_full_parallel_args args = {
&self,
&argv[0],
parsed.samples,
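The hunks above replace a plain `bool is_interrupted` with an atomically updated flag that the unblock function `full_ubf` sets while transcription runs without the GVL. The following is a minimal sketch of that pattern, not the binding's actual code: it assumes C11 `<stdatomic.h>` in place of Ruby's `RUBY_ATOMIC_SET`, and the names `abort_state`, `request_abort`, `should_abort`, and `run_steps` are hypothetical. How the real abort callback reads the flag is not shown in this diff, so the polling side is an assumption.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for ruby_whisper_abort_callback_user_data. */
typedef struct {
    atomic_int is_interrupted;
} abort_state;

/* Plays the role of the unblock function: may run on another thread. */
static void request_abort(void *user_data) {
    abort_state *state = user_data;
    atomic_store(&state->is_interrupted, 1);
}

/* Plays the role of an abort callback: returning true asks the worker to stop. */
static bool should_abort(void *user_data) {
    abort_state *state = user_data;
    return atomic_load(&state->is_interrupted) != 0;
}

/* Simulated worker loop that polls the flag before every step.
 * abort_at is the step at which an interrupt is simulated (-1 for never).
 * Returns the number of steps that completed. */
static int run_steps(int max_steps, int abort_at, abort_state *state) {
    int done = 0;
    for (int i = 0; i < max_steps; i++) {
        if (i == abort_at) {
            request_abort(state);
        }
        if (should_abort(state)) {
            break;
        }
        done++;
    }
    return done;
}
```

Keeping the flag in a per-call structure, as the diff does with a stack-allocated `abort_callback_user_data`, means an interrupt delivered to one transcription call cannot leak into the next one.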

View File

@ -0,0 +1,180 @@
#include "ruby_whisper.h"
#define LOG_QUEUE_CAPACITY 256
#define LOG_DEFAULT_CAPACITY 1024
void
ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue)
{
rb_nativethread_lock_initialize(&log_queue->lock);
rb_native_cond_initialize(&log_queue->cond);
log_queue->head = 0;
log_queue->tail = 0;
log_queue->size = 0;
log_queue->is_open = true;
log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY);
for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) {
// we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL
// this doesn't be freed because log queue lives until the end of process
char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY);
if (!slot) {
rb_raise(rb_eRuntimeError, "Could not allocate memory for log text");
}
ruby_whisper_log log = {
0,
slot,
0,
LOG_QUEUE_CAPACITY,
};
log_queue->logs[i] = log;
}
}
void
ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue)
{
rb_nativethread_lock_lock(&log_queue->lock);
log_queue->is_open = true;
rb_native_cond_signal(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
}
void
ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue)
{
rb_nativethread_lock_lock(&log_queue->lock);
log_queue->is_open = false;
rb_native_cond_broadcast(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
}
static size_t
calc_enough_cap(size_t len)
{
size_t quot = len / LOG_DEFAULT_CAPACITY;
size_t rem = len % LOG_DEFAULT_CAPACITY;
return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY;
}
void
ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text)
{
rb_nativethread_lock_lock(&log_queue->lock);
if (!log_queue->is_open) {
rb_nativethread_lock_unlock(&log_queue->lock);
return;
}
size_t len = strlen(text);
ruby_whisper_log *log = &log_queue->logs[log_queue->head];
if (len > log->capacity) {
size_t new_cap = calc_enough_cap(len);
// we cannot call Ruby API like REALLOC_N because this function is called without GVL
char *slot = realloc(log->text, new_cap);
if (!slot) {
rb_nativethread_lock_unlock(&log_queue->lock);
return;
}
log->text = slot;
log->capacity = new_cap;
}
// we cannot call Ruby API like MEMCPY because this function is called without GVL
memcpy(log->text, text, sizeof(char) * len);
log->length = len;
log->level = level;
log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY;
bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY;
log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1;
if (is_full) {
log_queue->tail = log_queue->head;
}
rb_native_cond_signal(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
}
static void*
ruby_whisper_log_queue_wait(void *args)
{
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
rb_native_cond_wait(&log_queue->cond, &log_queue->lock);
rb_nativethread_lock_unlock(&log_queue->lock);
return NULL;
}
static void
ruby_whisper_log_queue_wait_ubf(void *args)
{
ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args;
rb_native_cond_broadcast(&log_queue->cond);
}
typedef struct {
enum ggml_log_level level;
size_t length;
char *text;
} log_snapshot;
VALUE
ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue)
{
log_snapshot logs[LOG_QUEUE_CAPACITY];
rb_nativethread_lock_lock(&log_queue->lock);
while (log_queue->size == 0 && log_queue->is_open) {
rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue);
rb_nativethread_lock_lock(&log_queue->lock);
}
if (log_queue->size == 0 && !log_queue->is_open) {
rb_native_cond_broadcast(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
return Qnil;
}
size_t size = log_queue->size;
ruby_whisper_log *log;
size_t i;
for (i = 0; i < size; i++) {
log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY];
logs[i].level = log->level;
logs[i].length = log->length;
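// malloc(0) may legally return NULL, which would drop an empty message; request at least one byte
char *text = malloc(log->length > 0 ? log->length : 1);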
if (!text) {
logs[i].text = NULL;
continue;
}
logs[i].text = text;
memcpy(logs[i].text, log->text, log->length);
}
log_queue->size = 0;
log_queue->tail = log_queue->head;
rb_native_cond_signal(&log_queue->cond);
rb_nativethread_lock_unlock(&log_queue->lock);
VALUE rb_logs = rb_ary_new2(size);
VALUE rb_text;
for (i = 0; i < size; i++) {
if (!logs[i].text) {
continue;
}
rb_text = rb_str_new(logs[i].text, logs[i].length);
free(logs[i].text);
rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text));
}
return rb_logs;
}

@ -0,0 +1,47 @@
#ifndef RUBY_WHISPER_LOG_SETTABLE_H
#define RUBY_WHISPER_LOG_SETTABLE_H
#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \
static VALUE \
ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \
{ \
return ruby_whisper_log_queue_drain(&log_queue); \
} \
static void \
ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \
{ \
ruby_whisper_log_queue_enqueue(&log_queue, level, text); \
} \
static VALUE \
ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \
{ \
rb_iv_set(self, "@log_callback", log_callback); \
rb_iv_set(self, "@log_callback_user_data", user_data); \
if (NIL_P(log_callback)) { \
log_set(NULL, NULL); \
} else { \
ruby_whisper_log_queue_open(&log_queue); \
rb_funcall((mod), id_start_log_callback_thread, 0); \
log_set(ruby_whisper_##log_queue##_log_callback, NULL); \
} \
return Qnil; \
} \
static void \
ruby_whisper_##log_queue##_end_proc(VALUE args) \
{ \
ruby_whisper_log_queue_close(&log_queue); \
VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \
if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \
rb_funcall(log_callback_thread, id_join, 0); \
} \
}
#define LOG_SETTABLE_INIT(log_queue, mod) \
ruby_whisper_log_queue_initialize(&log_queue); \
rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \
rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \
rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \
rb_extend_object(mod, mLogSettable); \
rb_funcall(mLogSettable, id_extended, 1, mod);
#endif

@ -0,0 +1,49 @@
#include "ruby_whisper.h"
#include <stdio.h>
#include <unistd.h>
extern VALUE mParakeet;
extern VALUE mLogSettable;
extern VALUE cParakeetContext;
extern VALUE cParakeetSegment;
extern VALUE mOutputContext;
extern VALUE mOutputSegment;
extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet);
extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet);
extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet);
extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet);
extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext);
extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet);
static ruby_whisper_log_queue parakeet_log_queue;
LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set)
static VALUE
ruby_whisper_parakeet_s_system_info_str(VALUE self)
{
return rb_str_new2(parakeet_print_system_info());
}
void
init_ruby_whisper_parakeet(VALUE *mWhisper)
{
mParakeet = rb_define_module_under(*mWhisper, "Parakeet");
rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version()));
LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet)
rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0);
init_ruby_whisper_parakeet_params(&mParakeet);
init_ruby_whisper_parakeet_token(&mParakeet);
init_ruby_whisper_parakeet_segment(&mParakeet);
cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet);
init_ruby_whisper_parakeet_context_params(&cParakeetContext);
init_ruby_whisper_parakeet_model(&mParakeet);
rb_include_module(cParakeetContext, mOutputContext);
rb_include_module(cParakeetSegment, mOutputSegment);
}

@ -0,0 +1,304 @@
#include "ruby_whisper.h"
#define ITERATE_SEGMENT_ATTRS(ITERATOR) \
ITERATOR(get_segment_t0, LONG) \
ITERATOR(get_segment_t1, LONG) \
ITERATOR(get_segment_text, STRING) \
ITERATOR(n_tokens, INT)
#define ITERATE_TOKEN_ATTRS(ITERATOR) \
ITERATOR(get_token_text, STRING) \
ITERATOR(get_token_id, INT) \
ITERATOR(get_token_p, FLOAT)
// LL2NUM: long is only 32 bits on LLP64 platforms such as 64-bit Windows
#define VAL_FROM_LONG(v) LL2NUM(v)
#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v)
#define VAL_FROM_INT(v) INT2NUM(v)
#define VAL_FROM_FLOAT(v) DBL2NUM(v)
#define READER(type) VAL_FROM_##type
extern ID id_to_s;
extern ID id___method__;
extern ID id_to_enum;
extern ID id_new;
extern VALUE cParakeetContext;
extern VALUE eError;
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples);
extern VALUE release_samples(VALUE rb_parsed_args);
extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data);
extern const rb_data_type_t ruby_whisper_parakeet_params_type;
extern const rb_data_type_t ruby_whisper_parakeet_context_params_type;
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context);
static void
ruby_whisper_parakeet_context_free(void *p)
{
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
if (rwpc->context) {
parakeet_free(rwpc->context);
rwpc->context = NULL;
}
xfree(rwpc);
}
static size_t
ruby_whisper_parakeet_context_memsize(const void *p)
{
ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p;
if (!rwpc) {
return 0;
}
size_t size = sizeof(*rwpc);
return size;
}
const rb_data_type_t ruby_whisper_parakeet_context_type = {
"ruby_whisper_parakeet_context",
{0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_parakeet_context_allocate(VALUE klass)
{
ruby_whisper_parakeet_context *rwpc;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
rwpc->context = NULL;
return obj;
}
typedef struct {
struct parakeet_context **context;
char *model_path;
struct parakeet_context_params params;
} ruby_whisper_parakeet_context_init_args;
static void*
ruby_whisper_parakeet_context_init_without_gvl(void *args)
{
ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args;
*init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params);
return NULL;
}
static VALUE
ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self)
{
ruby_whisper_parakeet_context *rwpc;
VALUE model_path;
VALUE context_params;
struct parakeet_context_params params;
rb_scan_args(argc, argv, "11", &model_path, &context_params);
TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc);
model_path = ruby_whisper_normalize_model_path(model_path);
if (!rb_respond_to(model_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context");
}
if (NIL_P(context_params)) {
params = parakeet_context_default_params();
} else {
ruby_whisper_parakeet_context_params *rwpcp;
GetParakeetContextParams(context_params, rwpcp);
params = rwpcp->params;
}
ruby_whisper_parakeet_context_init_args init_args = {
&rwpc->context,
StringValueCStr(model_path),
params,
};
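rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL);
// keep model_path alive while init_args.model_path points into its buffer
RB_GC_GUARD(model_path);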
if (rwpc->context == NULL) {
rb_raise(rb_eRuntimeError, "Failed to load model");
}
return Qnil;
}
static VALUE
ruby_whisper_parakeet_context_full_n_segments(VALUE self)
{
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(self, rwpc);
return INT2NUM(parakeet_full_n_segments(rwpc->context));
}
#define DEF_SEGMENT_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \
{ \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetContext(self, rwpc); \
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \
}
ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR)
#define DEF_TOKEN_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \
{ \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetContext(self, rwpc); \
return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \
}
ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR)
static VALUE
ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token)
{
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(self, rwpc);
parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token));
return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data);
}
static VALUE
ruby_whisper_parakeet_context_each_segment(VALUE self)
{
if (!rb_block_given_p()) {
const VALUE method_name = rb_funcall(self, id___method__, 0);
return rb_funcall(self, id_to_enum, 1, method_name);
}
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(self, rwpc);
const int n_segments = parakeet_full_n_segments(rwpc->context);
for (int i = 0; i < n_segments; ++i) {
rb_yield(ruby_whisper_parakeet_segment_init(self, i));
}
return self;
}
typedef struct {
struct parakeet_context *context;
struct parakeet_full_params *params;
float *samples;
int n_samples;
int result;
} parakeet_full_without_gvl_args;
static void*
parakeet_full_without_gvl(void *rb_args)
{
parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args;
args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples);
return NULL;
}
typedef struct {
ruby_whisper_abort_callback_user_data *abort_callback_user_data;
} parakeet_full_ubf_args;
static void
parakeet_full_ubf(void *rb_args)
{
parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args;
RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1);
}
VALUE
ruby_whisper_parakeet_context_full_body(VALUE rb_args)
{
ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args;
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(*args->context, rwpc);
ruby_whisper_parakeet_params *rwpp;
GetParakeetParams(*args->params, rwpp);
ruby_whisper_abort_callback_user_data abort_callback_user_data = {
0,
NULL,
};
ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data);
parakeet_full_without_gvl_args full_without_gvl_args = {
rwpc->context,
&rwpp->params,
args->samples,
args->n_samples,
0
};
parakeet_full_ubf_args full_ubf_args = {
&abort_callback_user_data,
};
rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args);
return INT2NUM(full_without_gvl_args.result);
}
static VALUE
ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self)
{
if (argc < 2 || argc > 3) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
VALUE n_samples = argc == 2 ? Qnil : argv[2];
struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples);
ruby_whisper_full_args args = {
&self,
&argv[0],
parsed.samples,
parsed.n_samples,
};
VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed);
const int result = NUM2INT(rb_result);
if (result == 0) {
return self;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
}
}
static VALUE
ruby_whisper_parakeet_context_get_model(VALUE self)
{
return ruby_whisper_parakeet_model_s_new(self);
}
VALUE
init_ruby_whisper_parakeet_context(VALUE *mParakeet)
{
cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject);
rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate);
rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1);
rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2);
rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0);
rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2);
rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0);
rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0);
rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1);
#define REGISTER_SEGMENT_ATTR(name, type) \
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1);
ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR)
#define REGISTER_TOKEN_ATTR(name, type) \
rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2);
ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR)
return cParakeetContext;
}

@ -0,0 +1,117 @@
#include "ruby_whisper.h"
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(use_gpu, BOOL) \
ITERATOR(gpu_device, INT)
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_TO_INT(v) (NUM2INT(v))
#define READER(type) VAL_FROM_##type
#define WRITER(type) VAL_TO_##type
#define DEF_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_context_params_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_context_params *rwpcp; \
GetParakeetContextParams(self, rwpcp); \
return READER(type)(rwpcp->params.name); \
} \
static VALUE \
ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_context_params *rwpcp; \
GetParakeetContextParams(self, rwpcp); \
rwpcp->params.name = WRITER(type)(val); \
return val; \
}
enum {
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name,
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS
};
extern VALUE cParakeetContextParams;
typedef VALUE (*param_writer_t)(VALUE, VALUE);
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS];
static size_t
ruby_whisper_parakeet_context_params_memsize(const void *p)
{
if (!p) {
return 0;
}
return sizeof(ruby_whisper_parakeet_context_params);
}
const rb_data_type_t ruby_whisper_parakeet_context_params_type = {
"ruby_whisper_parakeet_context_params",
{0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,},
0, 0,
0,
};
static VALUE
ruby_whisper_parakeet_context_params_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_context_params *rwpcp;
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
}
static VALUE
ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef};
VALUE value;
ruby_whisper_parakeet_context_params *rwpcp;
int i;
TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp);
rwpcp->params = parakeet_context_default_params();
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return Qnil;
}
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values);
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) {
value = values[i];
if (value == Qundef) {
continue;
}
param_writers[i](self, value);
}
return Qnil;
}
ITERATE_ATTRS(DEF_ATTR)
void
init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext)
{
cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject);
rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate);
rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1);
int i = 0;
#define REGISTER_ATTR(name, type) \
param_names[i] = rb_intern(#name); \
param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \
rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \
rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \
i++;
ITERATE_ATTRS(REGISTER_ATTR)
}

@ -0,0 +1,84 @@
#include "ruby_whisper.h"
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(n_vocab) \
ITERATOR(n_audio_ctx) \
ITERATOR(n_audio_state) \
ITERATOR(n_audio_head) \
ITERATOR(n_audio_layer) \
ITERATOR(n_mels) \
ITERATOR(ftype)
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern VALUE cParakeetModel;
static void
ruby_whisper_parakeet_model_mark(void *p)
{
ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p;
if (!NIL_P(rwpm->context)) {
rb_gc_mark(rwpm->context);
}
}
static size_t
ruby_whisper_parakeet_model_memsize(const void *p)
{
if (!p) {
return 0;
}
return sizeof(ruby_whisper_parakeet_model);
}
static const rb_data_type_t ruby_whisper_parakeet_model_type = {
"ruby_whisper_parakeet_model",
{ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize},
0, 0,
0
};
static VALUE
ruby_whisper_parakeet_model_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_model *rwpm;
VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
rwpm->context = Qnil;
return model;
}
VALUE
ruby_whisper_parakeet_model_s_new(VALUE context)
{
const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel);
ruby_whisper_parakeet_model *rwpm;
TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm);
rwpm->context = context;
return model;
}
#define DEF_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_model_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_model *rwpm; \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetModel(self, rwpm); \
GetParakeetContext(rwpm->context, rwpc); \
return INT2NUM(parakeet_model_##name(rwpc->context)); \
}
ITERATE_ATTRS(DEF_ATTR)
void
init_ruby_whisper_parakeet_model(VALUE *mParakeet)
{
cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject);
rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate);
#define REGISTER_ATTR(name) \
rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0);
ITERATE_ATTRS(REGISTER_ATTR)
}

@ -0,0 +1,548 @@
#include "ruby_whisper.h"
#define ITERATE_PARAMS(ITERATOR) \
ITERATOR(n_threads, INT) \
ITERATOR(offset_ms, INT) \
ITERATOR(duration_ms, INT) \
ITERATOR(no_context, BOOL) \
ITERATOR(audio_ctx, INT)
#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \
ITERATOR(new_segment, DATA) \
ITERATOR(new_token, DATA) \
ITERATOR(progress, DATA) \
ITERATOR(encoder_begin, DATA)
#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback)
#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR)
#define ITERATE_CALLBACK_PARAMS(ITERATOR) \
ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \
ITERATOR(abort_callback)
enum {
#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name,
#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name,
#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data,
ITERATE_PARAMS(DEF_IDX)
ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK)
ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA)
RUBY_WHISPER_PARAKEET_NUM_PARAMS
};
#define VAL_TO_INT(v) (NUM2INT(v))
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_TO_BOOL(v) (RTEST(v))
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
extern VALUE cParakeetParams;
extern ID id_call;
extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc);
extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void);
extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container);
extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index);
extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data);
static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
typedef VALUE (*param_writer_t)(VALUE, VALUE);
static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS];
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_state *state;
int n_new;
} call_parakeet_new_segment_callbacks_args;
static void*
call_parakeet_new_segment_callbacks(void *v_args)
{
call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data);
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
const int n_segments = parakeet_full_n_segments_from_state(args->state);
for (int i = args->n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment);
for (long j = 0; j < n_callbacks; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
}
}
return NULL;
}
static void
ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_new_segment_callbacks_args args = {
container,
state,
n_new,
};
rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_context *context;
struct parakeet_state *state;
const parakeet_token_data *token_data;
} call_parakeet_new_token_callbacks_args;
static void*
call_parakeet_new_token_callbacks(void *v_args)
{
call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args;
VALUE token = Qnil;
const ruby_whisper_callback_container *container = args->container;
if (!NIL_P(container->callback)) {
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data);
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
if (NIL_P(token)) {
token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data);
}
for (long i = 0; i < n_callbacks; i++) {
VALUE cb = rb_ary_entry(container->callbacks, i);
rb_funcall(cb, id_call, 1, token);
}
return NULL;
}
static void
ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_new_token_callbacks_args args = {
container,
context,
state,
token_data,
};
rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_state *state;
int progress;
} call_parakeet_progress_callbacks_args;
static void*
call_parakeet_progress_callback(void *v_args)
{
call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
if (!NIL_P(container->callback)) {
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data);
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
for (long i = 0; i < n_callbacks; i++) {
VALUE cb = rb_ary_entry(container->callbacks, i);
rb_funcall(cb, id_call, 1, INT2NUM(args->progress));
}
return NULL;
}
static void
ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return;
}
call_parakeet_progress_callbacks_args args = {
container,
state,
progress,
};
rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct parakeet_state *state;
bool is_continued;
} call_parakeet_encoder_begin_callbacks_args;
static void*
call_parakeet_encoder_begin_callbacks(void *v_args)
{
call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
if (result == Qfalse) {
args->is_continued = false;
return NULL;
}
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
for (long i = 0; i < n_callbacks; i++) {
VALUE cb = rb_ary_entry(container->callbacks, i);
result = rb_funcall(cb, id_call, 0);
if (result == Qfalse) {
args->is_continued = false;
return NULL;
}
}
return NULL;
}
static bool
ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data)
{
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!ruby_whisper_callback_container_is_present(container)) {
return true;
}
call_parakeet_encoder_begin_callbacks_args args = {
container,
state,
true,
};
rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args);
return args.is_continued;
}
typedef struct {
const ruby_whisper_callback_container *container;
bool is_interrupted;
} call_parakeet_abort_callbacks_args;
static void*
call_parakeet_abort_callbacks(void *v_args)
{
call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args;
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
}
if (NIL_P(container->callbacks)) {
return NULL;
}
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (n_callbacks == 0) {
return NULL;
}
VALUE cb;
for (long i = 0; i < n_callbacks; i++) {
cb = rb_ary_entry(container->callbacks, i);
result = rb_funcall(cb, id_call, 0);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
}
return NULL;
}
static bool
ruby_whisper_parakeet_abort_callback(void *user_data)
{
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
if (is_interrupted) {
return true;
}
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
return false;
}
call_parakeet_abort_callbacks_args args = {
data->callback_container,
false,
};
rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args);
return args.is_interrupted;
}
#define CALLBACK_CONTAINER_NAME(name) name ## _container
void
ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
{
#define PARAM_NAME(name) name
#define USER_DATA_NAME(name) name##_user_data
#define REGISTER_CALLBACK(name) \
if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \
rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \
rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \
rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \
}
ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK)
if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) {
abort_callback_user_data->callback_container = rwpp->abort_callback_container;
}
rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback;
rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data;
}
static void
ruby_whisper_parakeet_params_mark(void *p)
{
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
#define MARK_CONTAINER(name) \
if (rwpp->name##_container) { \
ruby_whisper_callback_container_mark(rwpp->name##_container); \
}
ITERATE_CALLBACK_PARAMS(MARK_CONTAINER)
}
static void
ruby_whisper_parakeet_params_free(void *p)
{
ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p;
#define FREE_CONTAINER(name) \
if (rwpp->name##_container) { \
xfree(rwpp->name##_container); \
}
ITERATE_CALLBACK_PARAMS(FREE_CONTAINER)
xfree(rwpp);
}
static size_t
ruby_whisper_parakeet_params_memsize(const void *p)
{
const struct ruby_whisper_parakeet_params *params = p;
if (!params) {
return 0;
}
return sizeof(ruby_whisper_parakeet_params);
}
const rb_data_type_t ruby_whisper_parakeet_params_type = {
"ruby_whisper_parakeet_params",
{ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,},
0, 0,
0
};
#define READER(type) VAL_FROM_##type
#define WRITER(type) VAL_TO_##type
#define DEF_PARAM_ATTR(name, type) \
static VALUE \
ruby_whisper_parakeet_params_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
return READER(type)(rwpp->params.name); \
} \
static VALUE \
ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
rwpp->params.name = WRITER(type)(val); \
return val; \
}
#define DEF_CALLBACK_PARAM_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_params_get_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \
} \
static VALUE \
ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \
return val; \
}
#define DEF_USER_DATA_PARAM_ATTR(name) \
static VALUE \
ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \
} \
static VALUE \
ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \
return val; \
}
#define DEF_HOOK(name, data) \
static VALUE \
ruby_whisper_parakeet_params_on_##name(VALUE self) \
{ \
ruby_whisper_parakeet_params *rwpp; \
GetParakeetParams(self, rwpp); \
const VALUE blk = rb_block_proc(); \
if (NIL_P(rwpp->name##_callback_container->callbacks)) { \
rwpp->name##_callback_container->callbacks = rb_ary_new(); \
} \
rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \
return Qnil; \
}
ITERATE_PARAMS(DEF_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR)
ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _)
static VALUE
ruby_whisper_parakeet_params_abort_on(VALUE self)
{
ruby_whisper_parakeet_params *rwpp;
GetParakeetParams(self, rwpp);
const VALUE blk = rb_block_proc();
if (NIL_P(rwpp->abort_callback_container->callbacks)) {
rwpp->abort_callback_container->callbacks = rb_ary_new();
}
rb_ary_push(rwpp->abort_callback_container->callbacks, blk);
return Qnil;
}
static VALUE
ruby_whisper_parakeet_params_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_params *rwpp;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
return obj;
}
static VALUE
ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef};
VALUE value;
ruby_whisper_parakeet_params *rwpp;
int i;
TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp);
#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate();
ITERATE_CALLBACK_PARAMS(INIT_CONTAINER)
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return Qnil;
}
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values);
for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) {
value = values[i];
if (value == Qundef) {
continue;
}
param_writers[i](self, value);
}
return Qnil;
}
void
init_ruby_whisper_parakeet_params(VALUE *mParakeet)
{
cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject);
rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate);
rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1);
int i = 0;
#define REGISTER_PARAM(name) \
param_names[i] = rb_intern(#name); \
param_writers[i] = ruby_whisper_parakeet_params_set_##name; \
rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \
rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \
i++;
#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name)
#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name)
#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data)
ITERATE_PARAMS(REGISTER_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR)
ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR)
#define REGISTER_HOOK(name, data) \
rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0);
ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _)
rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0);
}

View File

@ -0,0 +1,157 @@
#include "ruby_whisper.h"
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(start_time, t0, TIME) \
ITERATOR(end_time, t1, TIME) \
ITERATOR(text, text, STRING)
enum {
#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name,
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
};
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
#define VAL_FROM_STRING(v) (rb_str_new2(v))
#define READER(type) VAL_FROM_##type
#define DEF_ATTR(rb_name, c_name, type) \
static VALUE \
ruby_whisper_parakeet_get_##rb_name(VALUE self) \
{ \
ruby_whisper_parakeet_segment *rwps; \
GetParakeetSegment(self, rwps); \
ruby_whisper_parakeet_context *rwpc; \
GetParakeetContext(rwps->context, rwpc); \
return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \
}
extern ID id___method__;
extern ID id_to_enum;
extern VALUE cParakeetSegment;
extern VALUE sym_start_time;
extern VALUE sym_end_time;
extern VALUE sym_text;
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token);
static void
rb_whisper_parakeet_segment_mark(void *p)
{
ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p;
rb_gc_mark(rwps->context);
}
static size_t
ruby_whisper_parakeet_segment_memsize(const void *p)
{
const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p;
if (!rwps) {
return 0;
}
return sizeof(*rwps);
}
static const rb_data_type_t ruby_whisper_parakeet_segment_type = {
"ruby_whisper_parakeet_segment",
{rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_parakeet_segment_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_segment *rwps;
return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
}
VALUE
ruby_whisper_parakeet_segment_init(VALUE context, int index)
{
ruby_whisper_parakeet_segment *rwps;
const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment);
TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps);
rwps->context = context;
rwps->index = index;
return segment;
}
ITERATE_ATTRS(DEF_ATTR)
static VALUE
ruby_whisper_parakeet_segment_each_token(VALUE self)
{
if (!rb_block_given_p()) {
const VALUE method_name = rb_funcall(self, id___method__, 0);
return rb_funcall(self, id_to_enum, 1, method_name);
}
ruby_whisper_parakeet_segment *rwps;
GetParakeetSegment(self, rwps);
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(rwps->context, rwpc);
const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index);
for (int i = 0; i < n_tokens; i++) {
rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i));
}
return self;
}
static VALUE
ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys)
{
ruby_whisper_parakeet_segment *rwps;
GetParakeetSegment(self, rwps);
ruby_whisper_parakeet_context *rwpc;
GetParakeetContext(rwps->context, rwpc);
VALUE hash = rb_hash_new();
long n_keys;
if (NIL_P(keys)) {
keys = rb_ary_new3(
RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS,
sym_start_time,
sym_end_time,
sym_text
);
n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS;
} else {
n_keys = RARRAY_LEN(keys);
if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) {
return hash;
}
}
for (long i = 0; i < n_keys; i++) {
VALUE key = rb_ary_entry(keys, i);
#define CHECK_AND_SET_KEY(rb_name, c_name, type) \
if (key == sym_##rb_name) { \
rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \
}
ITERATE_ATTRS(CHECK_AND_SET_KEY)
}
return hash;
}
void
init_ruby_whisper_parakeet_segment(VALUE *mParakeet)
{
cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject);
rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate);
#define REGISTER_ATTR(rb_name, c_name, type) \
rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0);
ITERATE_ATTRS(REGISTER_ATTR)
rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0);
rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1);
}
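
Reviewer note: `deconstruct_keys` is what lets a segment take part in Ruby pattern matching. The sketch below is a pure-Ruby stand-in that mirrors the C control flow above (`nil` means all keys, a request for more keys than there are attributes returns an empty hash, unknown keys are skipped). `FakeSegment` and its sample values are hypothetical; the real class reads its values from the parakeet context.

```ruby
# Hypothetical pure-Ruby stand-in for Parakeet::Segment, for illustration only.
class FakeSegment
  ATTRS = %i[start_time end_time text].freeze

  def initialize(start_time:, end_time:, text:)
    @values = {start_time: start_time, end_time: end_time, text: text}
  end

  # nil means "all keys"; asking for more keys than exist yields an empty
  # hash; keys that are not attributes are silently skipped.
  def deconstruct_keys(keys)
    keys ||= ATTRS
    return {} if keys.length > ATTRS.length

    keys.each_with_object({}) do |key, hash|
      hash[key] = @values[key] if ATTRS.include?(key)
    end
  end
end

segment = FakeSegment.new(start_time: 0, end_time: 16_680, text: "ask not")

# Ruby calls deconstruct_keys([:start_time, :text]) for this pattern.
case segment
in {start_time:, text:}
  puts "#{start_time}: #{text}"
end

p segment.deconstruct_keys(nil).keys     # [:start_time, :end_time, :text]
p segment.deconstruct_keys(%i[a b c d])  # {}
```

Because only the requested keys are looked up, a pattern that names one attribute does not pay for fetching the others.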

View File

@ -0,0 +1,188 @@
#include "ruby_whisper.h"
#define ITERATE_MEMBERS(ITERATOR) \
ITERATOR(id, id, id, id, INT) \
ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \
ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \
ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \
ITERATOR(probability, probability, p, p, FLOAT) \
ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \
ITERATOR(start_time, start_time, start_time, t0, TIME) \
ITERATOR(end_time, end_time, end_time, t1, TIME) \
ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL)
#define ITERATE_ATTRS(ITERATOR) \
ITERATOR(text, text, text, text, STRING)
enum {
#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name,
ITERATE_MEMBERS(DEF_IDX)
ITERATE_ATTRS(DEF_IDX)
RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS,
};
#define VAL_FROM_INT(v) (INT2NUM(v))
#define VAL_FROM_FLOAT(v) (DBL2NUM(v))
#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10))
#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse)
#define VAL_FROM_STRING(v) (rb_str_new2(v))
#define READER(type) VAL_FROM_##type
#define MEMBER_NAME(name) name
#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \
static VALUE \
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
{ \
ruby_whisper_parakeet_token *rwpt; \
GetParakeetToken(self, rwpt); \
return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \
}
#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \
static VALUE \
ruby_whisper_parakeet_token_get_##c_name(VALUE self) \
{ \
ruby_whisper_parakeet_token *rwpt; \
GetParakeetToken(self, rwpt); \
return rwpt->p_name; \
}
VALUE cParakeetToken;
#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key;
ITERATE_MEMBERS(DEC_ATTR_SYMS)
ITERATE_ATTRS(DEC_ATTR_SYMS)
static void
ruby_whisper_parakeet_token_mark(void *p)
{
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
rb_gc_mark(rwpt->text);
}
static void
ruby_whisper_parakeet_token_free(void *p)
{
ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p;
if (rwpt->token_data) {
xfree(rwpt->token_data);
rwpt->token_data = NULL;
}
xfree(rwpt);
}
static size_t
ruby_whisper_parakeet_token_memsize(const void *p)
{
const ruby_whisper_parakeet_token *rwpt = (const ruby_whisper_parakeet_token *)p;
if (!rwpt) {
return 0;
}
size_t size = sizeof(*rwpt);
if (rwpt->token_data) {
size += sizeof(*rwpt->token_data);
}
return size;
}
static const rb_data_type_t ruby_whisper_parakeet_token_type = {
"ruby_whisper_parakeet_token",
{ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize},
0, 0,
0,
};
static VALUE
ruby_whisper_parakeet_token_s_allocate(VALUE klass)
{
ruby_whisper_parakeet_token *rwpt;
VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
rwpt->token_data = NULL;
rwpt->text = Qnil;
return token;
}
VALUE
ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data)
{
const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken);
ruby_whisper_parakeet_token *rwpt;
TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt);
rwpt->token_data = ALLOC(parakeet_token_data);
*rwpt->token_data = *token_data;
rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id));
return token;
}
VALUE
ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token)
{
parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token);
return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data);
}
ITERATE_MEMBERS(DEF_MEMBER_ATTR)
// Define #text using parakeet_token_to_str or parakeet_token_to_text
ITERATE_ATTRS(DEF_ATTR)
static VALUE
ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys)
{
ruby_whisper_parakeet_token *rwpt;
GetParakeetToken(self, rwpt);
VALUE hash = rb_hash_new();
long n_keys = 0;
if (NIL_P(keys)) {
VALUE attrs[] = {
#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key,
ITERATE_MEMBERS(LIST_SYMS)
ITERATE_ATTRS(LIST_SYMS)
};
keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs);
n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS;
} else {
n_keys = RARRAY_LEN(keys);
if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) {
return hash;
}
}
for (long i = 0; i < n_keys; i++) {
VALUE key = rb_ary_entry(keys, i);
#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \
if (key == sym_##s_key) { \
rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \
}
ITERATE_MEMBERS(CHECK_AND_SET_KEY)
ITERATE_ATTRS(CHECK_AND_SET_KEY)
}
return hash;
}
void
init_ruby_whisper_parakeet_token(VALUE *mParakeet)
{
cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject);
rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate);
#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \
sym_##s_key = ID2SYM(rb_intern(#s_key)); \
rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0);
ITERATE_MEMBERS(REGISTER_ATTR)
ITERATE_ATTRS(REGISTER_ATTR)
rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1);
}

View File

@ -0,0 +1,58 @@
#include "ruby_whisper.h"
#include "common-whisper.h"
#include <string>
#include <vector>
#ifdef __cplusplus
extern "C" {
#endif
extern const rb_data_type_t ruby_whisper_parakeet_context_type;
extern const rb_data_type_t ruby_whisper_parakeet_params_type;
extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args);
extern ID id_to_path;
extern ID id_new;
extern VALUE eError;
VALUE
ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params)
{
if (rb_respond_to(audio_path, id_to_path)) {
audio_path = rb_funcall(audio_path, id_to_path, 0);
}
std::string fname = StringValueCStr(audio_path);
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
if (!read_audio_data(fname, pcmf32, pcmf32s, false)) {
rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str());
return Qnil;
}
ruby_whisper_parakeet_context *rwpc;
ruby_whisper_parakeet_params *rwpp;
GetParakeetContext(self, rwpc);
GetParakeetParams(params, rwpp);
ruby_whisper_full_args args = {
&self,
&params,
pcmf32.data(),
(int)pcmf32.size(),
};
VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args);
const int result = NUM2INT(rb_result);
if (result == 0) {
return self;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result));
}
}
#ifdef __cplusplus
}
#endif

View File

@ -76,8 +76,8 @@ static ID id_vad;
static ID id_vad_model_path;
static ID id_vad_params;
static void
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
void
ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc)
{
if (rwc == NULL) return;
@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
rb_gc_mark(rwc->callbacks);
}
static ruby_whisper_callback_container*
rb_whisper_callback_container_allocate() {
ruby_whisper_callback_container*
ruby_whisper_callback_container_allocate() {
ruby_whisper_callback_container *container;
container = ALLOC(ruby_whisper_callback_container);
container->context = NULL;
@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() {
return container;
}
static void
rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc)
{
if (rwc == NULL) return;
rb_gc_mark(rwc->user_data);
rb_gc_mark(rwc->callback);
rb_gc_mark(rwc->callbacks);
}
static ruby_whisper_abort_callback_container*
rb_whisper_abort_callback_container_allocate() {
ruby_whisper_abort_callback_container *container;
container = ALLOC(ruby_whisper_abort_callback_container);
container->context = NULL;
container->user_data = Qnil;
container->callback = Qnil;
container->callbacks = Qnil;
container->is_interrupted = false;
return container;
}
static bool
bool
ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) {
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
}
static bool
ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) {
return !NIL_P(container->callback) || !NIL_P(container->callbacks);
}
typedef struct {
const ruby_whisper_callback_container *container;
struct whisper_state *state;
@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s
}
typedef struct {
const ruby_whisper_abort_callback_container *container;
struct whisper_state *state;
const ruby_whisper_callback_container *container;
bool is_interrupted;
} call_abort_callbacks_args;
static void*
call_abort_callbacks(void *v_args) {
call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args;
const ruby_whisper_abort_callback_container *container = args->container;
if (container->is_interrupted) {
args->is_interrupted = true;
return NULL;
}
const ruby_whisper_callback_container *container = args->container;
VALUE result = Qnil;
if (!NIL_P(container->callback)) {
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
result = rb_funcall(container->callback, id_call, 1, container->user_data);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) {
if (NIL_P(container->callbacks)) {
return NULL;
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
const long n_callbacks = RARRAY_LEN(container->callbacks);
if (0 == n_callbacks) {
return NULL;
}
for (int j = 0; j < callbacks_len; j++) {
for (int j = 0; j < n_callbacks; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
if (!NIL_P(result) && Qfalse != result) {
VALUE result = rb_funcall(cb, id_call, 0);
if (RTEST(result)) {
args->is_interrupted = true;
return NULL;
}
@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) {
}
static bool abort_callback(void * user_data) {
const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data;
ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data;
if (container->is_interrupted) {
int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted);
if (is_interrupted) {
return true;
}
if (!ruby_whisper_abort_callback_container_is_present(container)) {
if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) {
return false;
}
call_abort_callbacks_args args = {
container,
NULL,
data->callback_container,
false
};
rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args);
@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors)
return;
}
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription");
}
if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) {
rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription");
}
// new_segment_callback is called only after multiple threads are joined
// progress_callback is not called when parallel
if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) {
rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription");
}
if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) {
if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) {
rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription");
}
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
if (!NIL_P(log_callback)) {
rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription");
}
}
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) {
if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) {
rwp->new_segment_callback_container->context = context;
rwp->params.new_segment_callback = new_segment_callback;
@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
abort_callback_user_data->callback_container = rwp->abort_callback_container;
rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback;
rwp->abort_callback_container->is_interrupted = false;
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
rwp->params.abort_callback_user_data = (void *)abort_callback_user_data;
}
static void set_vad_params(ruby_whisper_params *rwp)
@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp)
rwp->params.vad_params = rwvp->params;
}
/*
TODO: Set abort callback to trap SIGINT and SIGTERM
*/
void
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors)
prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data)
{
check_thread_safety(rwp, n_processors);
register_callbacks(rwp, context);
register_callbacks(rwp, context, abort_callback_user_data);
set_vad_params(rwp);
}
@ -421,10 +376,10 @@ void
rb_whisper_params_mark(void *p)
{
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_abort_callback_container_mark(rwp->abort_callback_container);
ruby_whisper_callback_container_mark(rwp->new_segment_callback_container);
ruby_whisper_callback_container_mark(rwp->progress_callback_container);
ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container);
ruby_whisper_callback_container_mark(rwp->abort_callback_container);
rb_gc_mark(rwp->vad_params);
}
@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass)
}
rwp->diarize = false;
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate();
rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate();
rwp->progress_callback_container = ruby_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate();
rwp->abort_callback_container = ruby_whisper_callback_container_allocate();
return obj;
}
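
Reviewer note: the reworked `call_abort_callbacks` in this file reduces to a short rule. The single `callback` is called with `user_data`, the hook blocks are called with no arguments, and the first truthy result aborts. The Ruby sketch below restates that rule. `AbortContainer` and `interrupted?` are hypothetical names for illustration; the sketch leaves out the atomic `is_interrupted` flag and the GVL handling that the C code performs.

```ruby
# Hypothetical Ruby mirror of the C control flow; not part of the binding.
AbortContainer = Struct.new(:callback, :user_data, :callbacks)

def interrupted?(container)
  # The single callback, if set, receives user_data.
  return true if container.callback && container.callback.call(container.user_data)

  # Hook blocks are called without arguments; any truthy result aborts.
  Array(container.callbacks).any? { |hook| hook.call }
end

calls = 0
hooks = [
  -> { calls += 1; false },
  -> { calls += 1; :stop },  # truthy, so iteration stops here
  -> { calls += 1; true },
]
with_hooks = AbortContainer.new(nil, nil, hooks)
with_callback = AbortContainer.new(->(data) { data == :halt }, :halt, nil)
empty = AbortContainer.new(nil, nil, nil)

p interrupted?(with_hooks)     # true
p calls                        # 2
p interrupted?(with_callback)  # true
p interrupted?(empty)          # false
```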

View File

@ -4,12 +4,12 @@
extern ID id___method__;
extern ID id_to_enum;
static VALUE sym_start_time;
static VALUE sym_end_time;
static VALUE sym_text;
static VALUE sym_no_speech_prob;
static VALUE sym_speaker_turn_next;
static VALUE sym_n_tokens;
VALUE sym_start_time;
VALUE sym_end_time;
VALUE sym_text;
VALUE sym_no_speech_prob;
VALUE sym_speaker_turn_next;
VALUE sym_n_tokens;
extern const rb_data_type_t ruby_whisper_type;

View File

@ -16,6 +16,8 @@ extern ID id_to_path;
extern ID transcribe_option_names[1];
extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors);
extern VALUE full_body(VALUE rb_args);
extern VALUE full_parallel_body(VALUE rb_args);
typedef struct{
struct whisper_context *context;
@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args)
return NULL;
}
typedef struct {
ruby_whisper_abort_callback_container *abort_callback_container;
} transcribe_ubf_args;
static void
transcribe_ubf(void *rb_args)
{
transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args;
args->abort_callback_container->is_interrupted = true;
}
/*
* transcribe a single file
* can emit to a block results
@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
// Commented out because it is work in progress
// {
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
// bool is_aborted = *(bool*)user_data;
// return !is_aborted;
// };
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
// }
prepare_transcription(rwp, &self, n_processors);
transcribe_without_gvl_args args = {
rw->context,
&rwp->params,
pcmf32.data(),
pcmf32.size(),
n_processors,
0,
};
transcribe_ubf_args ubf_args = {
rwp->abort_callback_container,
};
rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args);
if (args.result != 0) {
VALUE rb_result;
if (n_processors == 1) {
ruby_whisper_full_args args = {
&self,
&params,
pcmf32.data(),
(int)pcmf32.size(),
};
rb_result = full_body((VALUE)&args);
} else {
ruby_whisper_full_parallel_args parallel_args = {
&self,
&params,
pcmf32.data(),
(int)pcmf32.size(),
n_processors,
};
rb_result = full_parallel_body((VALUE)&parallel_args);
}
const int result = NUM2INT(rb_result);
if (result != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}

View File

@ -1,15 +0,0 @@
module Whisper
class Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end
def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
end
end

View File

@ -0,0 +1,36 @@
require "mutex_m"
module Whisper
module LogSettable
class << self
def extended(base)
base.extend Mutex_m
end
end
private
def start_log_callback_thread
return if @log_callback_thread&.alive?
@log_callback_thread = Thread.new {
begin
while logs = drain_logs
begin
callback, user_data = synchronize {[@log_callback, @log_callback_user_data]}
next if callback.nil?
logs.each do |(level, text)|
callback.call level, text, user_data
end
rescue => err
$stderr.puts err
end
end
rescue => err
$stderr.puts err
end
}
end
end
end
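
Reviewer note: the drain loop in `start_log_callback_thread` can be tried without the native extension. In the sketch below a `Queue` stands in for the native log buffer and a plain `Mutex` replaces `Mutex_m`. The `LogPump` class, `emit`, `close`, and this `drain_logs` are all assumptions made for illustration. In particular, it assumes `drain_logs` blocks until entries arrive and returns `nil` once the source is closed.

```ruby
# Hypothetical stand-in for a module extended with LogSettable.
class LogPump
  def initialize
    @queue = Queue.new
    @lock = Mutex.new
    @log_callback = nil
    @log_callback_user_data = nil
  end

  def log_set(callback, user_data)
    @lock.synchronize do
      @log_callback = callback
      @log_callback_user_data = user_data
    end
    start_log_callback_thread
  end

  def emit(level, text)
    @queue << [level, text]
  end

  def close
    @queue.close
    @log_callback_thread&.join
  end

  private

  # Assumed behaviour: block for the first entry, then take whatever else is
  # already buffered; return nil once the queue is closed and empty.
  def drain_logs
    first = @queue.pop
    return nil if first.nil?

    logs = [first]
    logs << @queue.pop until @queue.empty?
    logs
  end

  def start_log_callback_thread
    return if @log_callback_thread&.alive?

    @log_callback_thread = Thread.new do
      while (logs = drain_logs)
        # Snapshot the callback under the lock so log_set can swap it safely.
        callback, user_data = @lock.synchronize { [@log_callback, @log_callback_user_data] }
        next if callback.nil?

        begin
          logs.each { |(level, text)| callback.call(level, text, user_data) }
        rescue => err
          $stderr.puts err
        end
      end
    end
  end
end

received = []
pump = LogPump.new
pump.log_set(->(level, text, data) { received << [level, text, data] }, :tag)
pump.emit(2, "loading model")
pump.emit(3, "done")
pump.close
p received
```

Reading the callback and its user data in one synchronized step is what keeps a concurrent `log_set` from pairing a new callback with stale user data.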

View File

@ -41,6 +41,8 @@ module Whisper
def cache
path = cache_path
return path if cache_path.exist?
headers = {}
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
request @uri, headers
@ -216,8 +218,18 @@ module Whisper
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
end
%w[
parakeet-tdt-0.6b-v3-f16
parakeet-tdt-0.6b-v3-f32
parakeet-tdt-0.6b-v3-q4_0
parakeet-tdt-0.6b-v3-q4_k
parakeet-tdt-0.6b-v3-q8_0
].each do |name|
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin")
end
@coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models|
next if name.end_with?("-tdrz") || name.start_with?("silero-")
next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-")
if matched = name.match(/\A(?<name>.*)-q\d_\d\z/)
name = matched[:name]
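
Reviewer note: the sketch below shows how the filter and the quantization regex in this hunk treat a few names. The hunk is truncated, so what the original block does after `name = matched[:name]` is not shown here. `coreml_base_name` and its return values are assumptions made for illustration, and the two Whisper-style names are sample inputs.

```ruby
# Regex copied from the hunk; it strips a "-q<digit>_<digit>" suffix.
QUANT_SUFFIX_RE = /\A(?<name>.*)-q\d_\d\z/

# Hypothetical helper: nil for names the filter skips, otherwise the name
# with any recognised quantization suffix removed.
def coreml_base_name(name)
  return nil if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-")

  matched = name.match(QUANT_SUFFIX_RE)
  matched ? matched[:name] : name
end

p coreml_base_name("base-q5_0")                        # "base"
p coreml_base_name("small.en-tdrz")                    # nil
p coreml_base_name("parakeet-tdt-0.6b-v3-q8_0")        # nil
p "parakeet-tdt-0.6b-v3-q4_k".match?(QUANT_SUFFIX_RE)  # false
```

The last line shows that the regex only recognises a digit after the underscore, so a `q4_k` name would keep its suffix if it ever reached the match; the new `parakeet-` guard skips those names before that point.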

View File

@ -0,0 +1,74 @@
module Whisper
module Output
module Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end
def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
end
module Segment
SRT_ESCAPES = {
"&" => "&amp;",
"<" => "&lt;",
">" => "&gt;",
}
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
def to_srt_cue
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
end
def to_webvtt_cue
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
end
private
def time_to_a(time)
sec, decimal_part = time.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
[hour, min, sec, decimal_part]
end
def srt_time(time)
"%02d:%02d:%02d,%03d" % time_to_a(time)
end
def srt_start_time
srt_time(start_time)
end
def srt_end_time
srt_time(end_time)
end
def srt_text
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
end
def webvtt_time(time)
"%02d:%02d:%02d.%03d" % time_to_a(time)
end
def webvtt_start_time
webvtt_time(start_time)
end
def webvtt_end_time
webvtt_time(end_time)
end
alias webvtt_text srt_text
end
end
end
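
Reviewer note: as a worked example of the timestamp and escaping helpers in `Output::Segment`, the sketch below copies `time_to_a`, the two format strings, and the escape table into a standalone module and feeds them millisecond values. The `CueFormat` module name and the sample inputs are hypothetical.

```ruby
# Hypothetical standalone copy of the private helpers, for illustration only.
module CueFormat
  ESCAPES = {"&" => "&amp;", "<" => "&lt;", ">" => "&gt;"}.freeze
  ESCAPES_RE = Regexp.union(ESCAPES.keys)

  module_function

  # Split a millisecond count into [hours, minutes, seconds, milliseconds].
  def time_to_a(time)
    sec, decimal_part = time.divmod(1000)
    min, sec = sec.divmod(60)
    hour, min = min.divmod(60)
    [hour, min, sec, decimal_part]
  end

  # SRT separates the fractional part with a comma.
  def srt_time(time)
    "%02d:%02d:%02d,%03d" % time_to_a(time)
  end

  # WebVTT separates the fractional part with a period.
  def webvtt_time(time)
    "%02d:%02d:%02d.%03d" % time_to_a(time)
  end

  def escape(text)
    text.gsub(ESCAPES_RE, ESCAPES)
  end
end

puts CueFormat.srt_time(3_723_450)  # 01:02:03,450
puts CueFormat.webvtt_time(16_680)  # 00:00:16.680
puts CueFormat.escape("a < b & c")  # a &lt; b &amp; c
```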

View File

@ -1,58 +0,0 @@
module Whisper
class Segment
SRT_ESCAPES = {
"&" => "&amp;",
"<" => "&lt;",
">" => "&gt;",
}
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
def to_srt_cue
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
end
def to_webvtt_cue
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
end
private
def time_to_a(time)
sec, decimal_part = time.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
[hour, min, sec, decimal_part]
end
def srt_time(time)
"%02d:%02d:%02d,%03d" % time_to_a(time)
end
def srt_start_time
srt_time(start_time)
end
def srt_end_time
srt_time(end_time)
end
def srt_text
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
end
def webvtt_time(time)
"%02d:%02d:%02d.%03d" % time_to_a(time)
end
def webvtt_start_time
webvtt_time(start_time)
end
def webvtt_end_time
webvtt_time(end_time)
end
alias webvtt_text srt_text
end
end

View File

@ -40,7 +40,21 @@ module Whisper
def self.log_set: (log_callback?, Object? user_data) -> log_callback
def self.system_info_str: () -> String
module Output
module Context
def to_srt: () -> String
def to_webvtt: () -> String
end
module Segment
def to_srt_cue: () -> String
def to_webvtt_cue: () -> String
end
end
class Context
include Output::Context
def self.new: (String | path | ::URI::HTTP) -> instance
# transcribe a single file
@ -139,17 +153,14 @@ module Whisper
| (Whisper::Params, _Samples, ?Integer n_samples) -> self
| (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
def to_srt: () -> String
def to_webvtt: () -> String
class Params
def self.new: (
use_gpu: boolish,
flash_attn: boolish,
gpu_device: Integer,
dtw_token_timestamps: boolish,
dtw_aheads_preset: Integer,
dtw_n_top: Integer | nil,
?use_gpu: boolish,
?flash_attn: boolish,
?gpu_device: Integer,
?dtw_token_timestamps: boolish,
?dtw_aheads_preset: Integer,
?dtw_n_top: Integer | nil,
) -> instance
def use_gpu=: (boolish) -> boolish
@ -444,6 +455,9 @@ module Whisper
def abort_on: { (Object user_data) -> boolish } -> void
end
module LogSettable
end
class Model
def self.pre_converted_models: () -> Hash[String, Model::URI]
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
@ -474,6 +488,8 @@ module Whisper
end
class Segment
include Output::Segment
type deconstructed_keys = {
start_time: (Integer | nil),
end_time: (Integer | nil),
@ -514,9 +530,6 @@ module Whisper
#
def each_token: { (Token) -> void } -> void
| () -> Enumerator[Token]
def to_srt_cue: () -> String
def to_webvtt_cue: () -> String
# Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next`
#
@ -528,7 +541,7 @@ module Whisper
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys
end
module Token
class Token
type deconstructed_keys = {
id: (Integer | nil),
tid: (Integer | nil),
@ -598,6 +611,336 @@ module Whisper
def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys
end
module Parakeet
extend LogSettable
VERSION: String
# Control logging output. The default behavior is to print to stderr.
#
def self.log_set: (nil, Object? user_data) -> nil
| (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil
def self.system_info_str: () -> String
class Context
include Output::Context
# Load a Parakeet model from the given file path.
#
def self.new: (String | path | ::URI::HTTP, ?Params) -> instance
# Transcribe a single audio file.
#
def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
# Not thread safe for the same context.
#
# The second argument `samples` must be an array of samples, an object that responds to `:length`,
# or a MemoryView of an array of floats. It must be 32-bit float PCM audio data.
#
def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self
| (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self
# Number of generated text segments.
#
def full_n_segments: () -> Integer
# Start time of a segment indexed by `segment_index` in centiseconds (1 centisecond = 10 milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer segment_index) -> Integer
# End time of a segment indexed by `segment_index` in centiseconds (1 centisecond = 10 milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer segment_index) -> Integer
# Text of a segment indexed by `segment_index`.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer segment_index) -> String
# Number of tokens in the segment indexed by `segment_index`.
#
def full_n_tokens: (Integer segment_index) -> Integer
# Text of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_text: (Integer segment_index, Integer token_index) -> String
# Token id of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer
# Probability of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_p: (Integer segment_index, Integer token_index) -> Float
# Token data of the token indexed by `token_index` in the segment indexed by `segment_index`.
#
def full_get_token_data: (Integer segment_index, Integer token_index) -> Token
def model: () -> Model
# Yields each Whisper::Parakeet::Segment:
#
# parakeet.transcribe("path/to/audio.wav", params)
# parakeet.each_segment do |segment|
# puts segment.text
# end
#
# Returns an `Enumerator` if no block given:
#
# parakeet.transcribe("path/to/audio.wav", params)
# enum = parakeet.each_segment
# enum.to_a # => [#<Whisper::Parakeet::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
class Params
def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance
def use_gpu: () -> boolish
def use_gpu=: (boolish) -> boolish
def gpu_device: () -> Integer
def gpu_device=: (Integer) -> Integer
end
end
class Params
def self.new: (
?n_threads: Integer,
?offset_ms: Integer,
?duration_ms: Integer,
?no_context: boolish,
?audio_ctx: Integer,
?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void,
?new_segment_callback_user_data: Object,
?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void,
?new_token_callback_user_data: Object,
?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void,
?progress_callback_user_data: Object,
?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish,
?encoder_begin_callback_user_data: Object,
?abort_callback: ^(Object user_data) -> boolish,
?abort_callback_user_data: Object
) -> instance
# Number of threads to use.
#
def n_threads=: (Integer) -> Integer
def n_threads: () -> Integer
# Start offset in ms.
#
def offset_ms=: (Integer) -> Integer
def offset_ms: () -> Integer
# Audio duration to process in ms.
#
def duration_ms=: (Integer) -> Integer
def duration_ms: () -> Integer
# If `true`, does not use past transcription (if any) as context.
#
def no_context=: (boolish) -> boolish
def no_context: () -> (true | false)
# Overwrite the audio context size. `0` uses the default value.
#
def audio_ctx=: (Integer) -> Integer
def audio_ctx: () -> Integer
# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void)
def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil)
# Sets user data passed to the last argument of new segment callback.
#
def new_segment_callback_user_data=: (Object?) -> Object?
def new_segment_callback_user_data: () -> Object?
# Sets token callback, called for every newly predicted token.
#
def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void)
def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil)
# Sets user data passed to the last argument of token callback.
#
def new_token_callback_user_data=: (Object?) -> Object?
def new_token_callback_user_data: () -> Object?
# Sets progress callback, called on each progress update.
#
# `progress` is an Integer between 0 and 100.
#
def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void)
def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil)
# Sets user data passed to the last argument of progress callback.
#
def progress_callback_user_data=: (Object?) -> Object?
def progress_callback_user_data: () -> Object?
# Sets encoder begin callback, called each time before the encoder starts.
#
# If it returns `false`, the computation is aborted.
#
def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish)
def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil)
# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object?) -> Object?
def encoder_begin_callback_user_data: () -> Object?
# Sets abort callback, called each time before ggml computation starts.
#
def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish)
def abort_callback: () -> ((^(Object user_data) -> boolish) | nil)
# Sets user data passed to the last argument of abort callback.
#
def abort_callback_user_data=: (Object?) -> Object?
def abort_callback_user_data: () -> Object?
# Hook called on new segment. Yields each Whisper::Parakeet::Segment.
#
def on_new_segment: { (Segment) -> void } -> void
# Hook called on new token. Yields each Whisper::Parakeet::Token.
#
def on_new_token: { (Token) -> void } -> void
# Hook called on progress update. Yields each progress `Integer` between 0 and 100.
#
def on_progress: { (Integer progress) -> void } -> void
# Hook called each time before the encoder starts.
#
def on_encoder_begin: { () -> boolish } -> void
# Calls the block to determine whether to abort. Return `true` when you want to abort.
#
def abort_on: { () -> boolish } -> void
end
class Segment
include Output::Segment
type deconstructed_keys = {
start_time: (Integer | nil),
end_time: (Integer | nil),
text: (String | nil)
}
# Start time in milliseconds.
#
def start_time: () -> Integer
# End time in milliseconds.
#
def end_time: () -> Integer
# Text of the segment.
#
def text: () -> String
# Yields each Whisper::Parakeet::Token:
#
# parakeet.each_segment.first.each_token do |token|
# p token
# end
#
# Returns an `Enumerator` if no block is given:
#
# parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...]
#
def each_token: { (Token) -> void } -> void
| () -> Enumerator[Token]
# Possible keys: `:start_time`, `:end_time`, `:text`
#
def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys
end
class Token
type deconstructed_keys = {
id: (Integer | nil),
duration_idx: (Integer | nil),
duration_value: (Integer | nil),
frame_index: (Integer | nil),
probability: (Float | nil),
log_probability: (Float | nil),
start_time: (Integer | nil),
end_time: (Integer | nil),
word_start: ((true | false) | nil),
text: (String | nil),
}
# Token ID.
#
def id: () -> Integer
# Index into the model's durations array.
#
def duration_idx: () -> Integer
# Actual duration value.
#
def duration_value: () -> Integer
# Frame index of the token.
#
def frame_index: () -> Integer
# Probability of the token.
#
def probability: () -> Float
# Log probability of the token.
#
def log_probability: () -> Float
# Start time of the token in milliseconds.
#
def start_time: () -> Integer
# End time of the token in milliseconds.
#
def end_time: () -> Integer
# Whether this token is the start of a word.
#
def word_start?: () -> (true | false)
# Text of the token.
#
def text: () -> String
def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys
end
class Model
def n_vocab: () -> Integer
def n_audio_ctx: () -> Integer
def n_audio_state: () -> Integer
def n_audio_head: () -> Integer
def n_audio_layer: () -> Integer
def n_mels: () -> Integer
def ftype: () -> Integer
end
end
module VAD
class Params
def self.new: (


@ -5,6 +5,8 @@ require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "fixtures", "jfk.wav")
Parakeet = Whisper::Parakeet
class << self
def whisper
return @whisper if @whisper


@ -129,6 +129,7 @@ class TestCallback < TestBase
return false
}
@whisper.transcribe(@audio, @params)
sleep 0.5 # wait for logs to be dequeued
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
Whisper.log_set ->(level, buffer, user_data) {}, nil
end


@ -0,0 +1,28 @@
require_relative "helper"
require "stringio"
class TestParakeet < TestBase
def test_log_set
log_callback = Parakeet.instance_variable_get("@log_callback")
user_data = Parakeet.instance_variable_get("@log_callback_user_data")
$stdout = StringIO.new
Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil
Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
sleep 0.1
$stdout.rewind
logs = $stdout.string
assert_match /loading model from/, logs
ensure
$stdout = STDOUT
Parakeet.log_set log_callback, user_data
end
def test_system_info_str
assert_match /\APARAKEET : /, Parakeet.system_info_str
end
def test_version
assert_instance_of String, Parakeet::VERSION
end
end


@ -0,0 +1,107 @@
require_relative "helper"
class TestParakeetCallback < TestBase
def setup
omit "Skipped to avoid downloading a large model" if ENV["CI"]
Whisper.instance_variable_set "@whisper", nil
GC.start
@params = Parakeet::Params.new
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
end
def test_new_segment_callback
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_kind_of Integer, n_new
assert n_new > 0
assert_same @parakeet, context
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - n_new + i
start_time = context.full_get_segment_t0(i_segment) * 10
end_time = context.full_get_segment_t1(i_segment) * 10
text = context.full_get_segment_text(i_segment)
assert_kind_of Integer, start_time
assert start_time >= 0
assert_kind_of Integer, end_time
assert end_time > 0
assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0
end
}
@parakeet.transcribe AUDIO, @params
end
def test_on_new_segment
seg = nil
index = 0
@params.on_new_segment do |segment|
assert_instance_of Parakeet::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text)
end
index += 1
end
@parakeet.transcribe AUDIO, @params
assert_equal 0, seg.start_time
assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text
end
def test_on_new_token
index = 0
@params.on_new_token do |token|
assert_instance_of Parakeet::Token, token
if index == 0
assert_instance_of Integer, token.start_time
assert_match "▁And", token.text
end
index += 1
end
@parakeet.transcribe AUDIO, @params
end
def test_on_progress
first = nil
@params.on_progress do |progress|
assert_kind_of Integer, progress
assert 0 <= progress && progress <= 100
first = progress if first.nil?
end
@parakeet.transcribe AUDIO, @params
assert_equal 0, first
end
def test_on_encoder_begin
i = 0
@params.on_encoder_begin do
i += 1
end
@parakeet.transcribe AUDIO, @params
assert i > 0
end
def test_abort_on
do_abort = false
@params.on_new_segment do |segment|
do_abort = true if segment.text.match?(/ask/)
end
i = 0
@params.abort_on do
i += 1
do_abort
end
@parakeet.transcribe(AUDIO, @params) rescue nil
assert i > 0
end
end
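The new-segment callback receives only the number of new segments, so their indices have to be derived from the running total. A minimal sketch, assuming (as in whisper.cpp's own callbacks) that the callback fires after the new segments have been appended. `new_segment_indices` and `cs_to_ms` are hypothetical helper names, not part of the gem.

```ruby
# Hypothetical helpers, not part of the whispercpp gem.

# Indices of the segments that were just appended, assuming the callback
# runs after `n_new` segments were added, giving `n_segments` in total.
def new_segment_indices(n_segments, n_new)
  first = n_segments - n_new
  (first...n_segments).to_a
end

# full_get_segment_t0/t1 report centiseconds; 1 centisecond = 10 ms.
def cs_to_ms(centiseconds)
  centiseconds * 10
end

p new_segment_indices(5, 2) # the last two of five segments
p cs_to_ms(1668)            # the value used in the RBS comment for full_get_segment_t0
```

When `n_new` is 1, this reduces to the single index `n_segments - 1`.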


@ -0,0 +1,116 @@
require_relative "helper"
require "stringio"
class TestParakeetContext < TestBase
def setup
omit "Skipped to avoid downloading a large model" if ENV["CI"]
Whisper.instance_variable_set "@whisper", nil
GC.start
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
@params = Parakeet::Params.new
end
def test_new
assert_instance_of Parakeet::Context, @parakeet
end
def test_new_with_params
log_callback = Parakeet.instance_variable_get(:@log_callback)
user_data = Parakeet.instance_variable_get(:@log_callback_user_data)
begin
logs = ""
Parakeet.log_set proc {|level, message| logs << message}, nil
params = Parakeet::Context::Params.new(use_gpu: false)
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params)
assert_instance_of Parakeet::Context, parakeet
assert_match /use gpu\s+=\s+0/, logs
ensure
Parakeet.log_set log_callback, user_data
end
end
sub_test_case "full" do
def setup
super
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
end
def test_full
@parakeet.full @params, @samples, @samples.length
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text
end
def test_full_without_length
@parakeet.full(@params, @samples)
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
end
def test_full_enumerator
samples = @samples.each
@parakeet.full @params, samples, @samples.length
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
end
def test_full_enumerator_without_length
samples = @samples.each
assert_raise ArgumentError do
@parakeet.full @params, samples
end
end
def test_full_enumerator_with_too_large_length
samples = @samples.each.take(10).to_enum
assert_raise StopIteration do
@parakeet.full @params, samples, 11
end
end
def test_full_with_memory_view
samples = JFKReader.new(AUDIO)
@parakeet.full @params, samples
segments = @parakeet.each_segment.to_a
assert_equal 1, segments.length
assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text
end
def test_full_with_memory_view_gc
samples = JFKReader.new(AUDIO)
@parakeet.full(@params, samples)
GC.start
require "fiddle"
Fiddle::MemoryView.export samples do |view|
assert_equal 176000, view.to_s.unpack("#{view.format}*").length
end
end
end
def test_transcribe
assert_nothing_raised do
@parakeet.transcribe AUDIO, @params
end
end
def test_transcribe_with_pathname
assert_nothing_raised do
@parakeet.transcribe Pathname(AUDIO), @params
end
end
def test_transcribe_with_nothing
assert_raise_message(/open/) do
@parakeet.transcribe "nothing", @params
end
end
end
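The `full` tests feed the model 32-bit float samples obtained from 16-bit PCM via `unpack("s<*")` and a division by `2**15`. The sketch below shows the same conversion on three in-memory samples instead of the WAV fixture. The sample values are made up for illustration.

```ruby
# Build three 16-bit little-endian PCM samples in memory instead of reading a WAV file.
pcm_bytes = [0, 16_384, -32_768].pack("s<*")

# Same conversion as in the test setup: signed 16-bit integers -> floats in [-1.0, 1.0).
samples = pcm_bytes.unpack("s<*").map { |i| i.to_f / 2**15 }

p samples # [0.0, 0.5, -1.0]
```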


@ -0,0 +1,24 @@
require_relative "helper"
class TestParakeetContextParams < TestBase
def setup
@params = Parakeet::Context::Params.new
end
def test_new
assert_instance_of Parakeet::Context::Params, @params
end
def test_attributes
assert_true @params.use_gpu
assert_instance_of Integer, @params.gpu_device
end
def test_attribute_writer
@params.use_gpu = false
assert_false @params.use_gpu
@params.gpu_device = 2
assert_equal 2, @params.gpu_device
end
end


@ -0,0 +1,21 @@
require_relative "helper"
class TestParakeetModel < TestBase
def test_model
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
assert_instance_of Parakeet::Model, parakeet.model
end
def test_attributes
parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin")
model = parakeet.model
assert_equal 10, model.n_vocab
assert_equal 3200, model.n_audio_ctx
assert_equal 8, model.n_audio_state
assert_equal 2, model.n_audio_head
assert_equal 1, model.n_audio_layer
assert_equal 16, model.n_mels
assert_equal 0, model.ftype
end
end


@ -0,0 +1,78 @@
require_relative "helper"
require "etc"
class TestParakeetParams < TestBase
PARAM_NAMES = [
:n_threads,
:offset_ms,
:duration_ms,
:no_context,
:audio_ctx
]
def setup
@params = Parakeet::Params.new
end
def test_new
assert_instance_of Parakeet::Params, @params
end
def test_n_threads
assert_equal [4, Etc.nprocessors].min, @params.n_threads
@params.n_threads = 1
assert_equal 1, @params.n_threads
end
def test_offset_ms
assert_equal 0, @params.offset_ms
@params.offset_ms = 10_000
assert_equal 10_000, @params.offset_ms
end
def test_duration_ms
assert_equal 0, @params.duration_ms
@params.duration_ms = 60_000
assert_equal 60_000, @params.duration_ms
end
def test_no_context
assert_equal true, @params.no_context
@params.no_context = false
assert_equal false, @params.no_context
end
def test_audio_ctx
assert_equal 0, @params.audio_ctx
@params.audio_ctx = 1
assert_equal 1, @params.audio_ctx
end
def test_new_with_kw_args
params = Parakeet::Params.new(n_threads: 1)
assert_equal 1, params.n_threads
assert_equal 0, params.offset_ms
end
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
def test_new_with_kw_args_default_values(param)
default_value = @params.send(param)
value = case [param, default_value]
in [*, true | false]
!default_value
in [*, Integer]
default_value + 1
end
params = Parakeet::Params.new(param => value)
assert_equal value, params.send(param)
PARAM_NAMES.reject {|name| name == param}.each do |name|
assert_equal @params.send(name), params.send(name)
end
end
end


@ -0,0 +1,42 @@
require_relative "helper"
class TestParakeetSegment < TestBase
def setup
omit "Skipped to avoid downloading a large model" if ENV["CI"]
@parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
@parakeet.transcribe AUDIO, Parakeet::Params.new
end
def test_segment
whole_text = ""
@parakeet.each_segment do |segment|
assert_instance_of Parakeet::Segment, segment
assert_kind_of Integer, segment.start_time
assert segment.end_time >= segment.start_time
assert_kind_of String, segment.text
whole_text << segment.text
end
assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text)
end
def test_deconstruct_keys
segment = @parakeet.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text
}
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text])
end
def test_deconstruct_keys_with_nil
segment = @parakeet.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text
}
assert_equal expected, segment.deconstruct_keys(nil)
end
end
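Because segments implement `deconstruct_keys`, they can be used with Ruby's `case`/`in` pattern matching. The sketch below uses a stand-in class with the same three keys so that it runs without a model. `FakeSegment` and its values are hypothetical and are not the gem's `Whisper::Parakeet::Segment`.

```ruby
# Stand-in for Whisper::Parakeet::Segment; NOT the gem's class.
class FakeSegment
  attr_reader :start_time, :end_time, :text

  def initialize(start_time, end_time, text)
    @start_time = start_time
    @end_time = end_time
    @text = text
  end

  # Return only the requested keys, or all keys when `keys` is nil.
  def deconstruct_keys(keys)
    all = { start_time: start_time, end_time: end_time, text: text }
    keys ? all.slice(*keys) : all
  end
end

# Illustrative values only.
segment = FakeSegment.new(0, 11_010, "And so, my fellow Americans")

description =
  case segment
  in { start_time:, end_time:, text: }
    "#{start_time}-#{end_time} ms: #{text}"
  end

puts description
```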


@ -0,0 +1,73 @@
require_relative "helper"
class TestParakeetToken < TestBase
ATTRS = %i[
id
duration_idx
duration_value
frame_index
probability
log_probability
start_time
end_time
word_start?
text
]
def setup
omit "Skipped to avoid downloading a large model" if ENV["CI"]
Whisper.instance_variable_set "@whisper", nil
GC.start
parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0")
params = Parakeet::Params.new
parakeet.transcribe AUDIO, params
@segment = parakeet.each_segment.first
end
def test_each_token
i = 0
@segment.each_token do |token|
i += 1
assert_instance_of Parakeet::Token, token
end
assert_equal 38, i
end
def test_each_token_without_block
assert_instance_of Enumerator, @segment.each_token
end
def test_token
token = @segment.each_token.first
assert_instance_of Parakeet::Token, token
assert_instance_of Integer, token.id
assert_instance_of Integer, token.duration_idx
assert_instance_of Integer, token.duration_value
assert_instance_of Integer, token.frame_index
assert_instance_of Float, token.probability
assert_instance_of Float, token.log_probability
assert_instance_of Integer, token.start_time
assert_instance_of Integer, token.end_time
assert_instance_of String, token.text
end
def test_text
assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."],
@segment.each_token.collect(&:text)
end
def test_deconstruct_keys_with_nil
token = @segment.each_token.first
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
assert_equal expected, token.deconstruct_keys(nil)
end
def test_deconstruct_keys_with_keys
token = @segment.each_token.first
expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h
assert_equal expected, token.deconstruct_keys(expected.keys)
end
end
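The token tests build the expected hash by mapping each reader name to its `deconstruct_keys` key, which means dropping the trailing `?` from the predicate `word_start?`. A small sketch of that mapping on a shortened attribute list:

```ruby
# Reader names as listed in ATTRS (shortened); `word_start?` is a predicate method.
attrs = %i[id start_time word_start? text]

# Strip a trailing "?" so the predicate maps to the `:word_start` key.
keys = attrs.map { |attr| attr.to_s.sub(/\?\z/, "").to_sym }

p keys # [:id, :start_time, :word_start, :text]
```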


@ -9,7 +9,7 @@ class TestVADSegment < TestBase
end
assert_raise do
segments.end_time
segment.end_time
end
assert_raise do


@ -149,6 +149,7 @@ class TestWhisper < TestBase
}
Whisper.log_set log_callback, user_data
Whisper::Context.new("base.en")
sleep 0.1 # wait for logs to be dequeued
assert logs.length > 30
logs.each do |log|


@ -23,7 +23,7 @@ Gem::Specification.new do |s|
s.test_files = s.files.select {|file| file.start_with? "test/"}
s.extensions << 'ext/extconf.rb'
s.required_ruby_version = '>= 3.1.0'
s.required_ruby_version = '>= 3.3.0'
#### Documentation and testing.
s.homepage = 'https://github.com/ggml-org/whisper.cpp'


@ -0,0 +1,30 @@
set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@)
set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@)
set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@)
set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@)
@PACKAGE_INIT@
set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@")
set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@")
set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@")
find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake)
find_library(parakeet_LIBRARY parakeet
REQUIRED
HINTS ${PARAKEET_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH
)
add_library(parakeet UNKNOWN IMPORTED)
set_target_properties(parakeet
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${parakeet_LIBRARY}"
INTERFACE_COMPILE_FEATURES cxx_std_11
POSITION_INDEPENDENT_CODE ON)
check_required_components(parakeet)

cmake/parakeet.pc.in

@ -0,0 +1,10 @@
prefix=@CMAKE_INSTALL_PREFIX@
exec_prefix=${prefix}
libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@
includedir=${prefix}/include
Name: parakeet
Description: Port of NVIDIA's Parakeet model in C/C++
Version: @PROJECT_VERSION@
Libs: -L${libdir} -lggml -lggml-base -lparakeet
Cflags: -I${includedir}


@ -107,6 +107,8 @@ else()
add_subdirectory(server)
add_subdirectory(quantize)
add_subdirectory(vad-speech-segments)
add_subdirectory(parakeet-cli)
add_subdirectory(parakeet-quantize)
if (WHISPER_SDL2)
add_subdirectory(stream)
add_subdirectory(command)


@ -0,0 +1,8 @@
set(TARGET parakeet-cli)
add_executable(${TARGET} parakeet-cli.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)


@ -0,0 +1,106 @@
# whisper.cpp/examples/parakeet-cli
This is an example of using the [Parakeet] model in whisper.cpp.
### Download converted model
```console
$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models
```
### Building
```console
$ cmake -B build -S .
$ cmake --build build --target parakeet-cli -j 12
```
### Usage
```console
$ ./build/bin/parakeet-cli --help
usage: ./build/bin/parakeet-cli [options] file0 file1 ...
supported audio formats: flac, mp3, ogg, wav
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path
-f, --file FILE [ ] input audio file
-ng, --no-gpu [false ] disable GPU
-dev N, --device N [0 ] GPU device to use
-ps, --print-segments [false ] print segment information
```
### Example
```console
$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav
Processing audio (176000 samples, 11.00 seconds)
Processing audio: total_frames=1101, chunk_size=1101
parakeet_decode: starting decode with n_frames=138
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
```
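The numbers in the log lines above are consistent with simple frame arithmetic. The Ruby sketch below reproduces them under three assumptions that are inferred from the output, not read from the source: a 16 kHz sample rate, one mel frame per 160 samples (10 ms), and an 8x reduction before decoding.

```ruby
SAMPLE_RATE = 16_000 # Hz, assumed
HOP_LENGTH  = 160    # samples per mel frame (10 ms), assumed
SUBSAMPLING = 8      # encoder downsampling factor, assumed

n_samples = 176_000  # from "Processing audio (176000 samples, 11.00 seconds)"

seconds        = n_samples.to_f / SAMPLE_RATE
total_frames   = n_samples / HOP_LENGTH + 1
encoder_frames = (total_frames.to_f / SUBSAMPLING).ceil

puts format("%.2f seconds, %d mel frames, %d encoder frames",
            seconds, total_frames, encoder_frames)
```

Under these assumptions the script prints 11.00 seconds, 1101 mel frames and 138 encoder frames, which matches `total_frames=1101` and `n_frames=138` in the log.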
To print segment information:
```console
$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments
Processing audio (176000 samples, 11.00 seconds)
Processing audio: total_frames=1101, chunk_size=1101
parakeet_decode: starting decode with n_frames=138
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
Segments (1):
Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country."
Tokens [38]:
[ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And"
[ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so"
[ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false ","
[ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my"
[ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f"
[ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell"
[ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow"
[ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer"
[ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic"
[ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans"
[10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false ","
[11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a"
[12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk"
[13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not"
[14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what"
[15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your"
[16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co"
[17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un"
[18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr"
[19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y"
[20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can"
[21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do"
[22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for"
[23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you"
[24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false ","
[25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a"
[26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk"
[27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what"
[28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you"
[29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can"
[30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do"
[31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for"
[32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your"
[33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co"
[34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un"
[35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr"
[36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y"
[37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "."
```
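The `word_start` flag in the token dump above is enough to rebuild words from sub-word tokens. A minimal Ruby sketch, using the first seven tokens copied from that output and assuming the leading `▁` marks a word boundary that should be dropped from the text:

```ruby
# [text, word_start] pairs copied from the first seven tokens of the dump above.
tokens = [
  ["▁And", true], ["▁so", true], [",", false], ["▁my", true],
  ["▁f", true], ["ell", false], ["ow", false]
]

# Start a new word whenever word_start is true; otherwise append to the current one.
words = tokens.each_with_object([]) do |(text, word_start), acc|
  piece = text.delete_prefix("▁")
  if word_start || acc.empty?
    acc << piece
  else
    acc.last << piece
  end
end

puts words.join(" ")
```

Punctuation tokens carry `word_start=false`, so they attach to the preceding word, as in the transcribed sentence.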
### Model conversion
Clone the original model from Hugging Face:
```console
$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
```
Convert the model:
```console
(venv) $ python models/convert-parakeet-to-ggml.py \
--model <path to cloned model> \
--out-dir models \
--out-name ggml-parakeet-tdt-0.6b-v3-f16.bin
```
[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3


@ -0,0 +1,243 @@
#include "parakeet.h"
#include "common-whisper.h"
#include <algorithm> // std::min
#include <cstdio>
#include <cstdlib>   // std::getenv, exit
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#include <fstream>
// command-line parameters
struct parakeet_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
bool use_gpu = true;
int32_t gpu_device = 0;
bool print_segments = false;
bool output_txt = false;
bool no_prints = false;
std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin";
std::string output_file = "";
std::vector<std::string> fname_inp = {};
};
static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params);
static char * requires_value_error(const std::string & arg) {
fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
exit(1);
}
static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) {
if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) {
params.gpu_device = std::stoi(env_device);
}
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') {
params.fname_inp.push_back(arg);
continue;
}
if (arg == "-h" || arg == "--help") {
parakeet_print_usage(argc, argv, params);
exit(0);
}
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); }
else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
parakeet_print_usage(argc, argv, params);
exit(1);
}
}
return true;
}
static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]);
fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", "");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device);
fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, "\n");
}
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
bool * is_first = (bool *) user_data;
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
char text_buf[256];
parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf));
printf("%s", text_buf);
fflush(stdout);
*is_first = false;
}
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
int main(int argc, char ** argv) {
ggml_backend_load_all();
parakeet_params params;
if (parakeet_params_parse(argc, argv, params) == false) {
return 1;
}
if (params.no_prints) {
parakeet_log_set(cb_log_disable, NULL);
}
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
parakeet_print_usage(argc, argv, params);
return 1;
}
struct parakeet_context_params ctx_params = parakeet_context_default_params();
ctx_params.use_gpu = params.use_gpu;
ctx_params.gpu_device = params.gpu_device;
if (!params.no_prints) {
fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
}
struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
if (pctx == nullptr) {
fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
return 1;
}
if (!params.no_prints) {
fprintf(stderr, "Successfully loaded Parakeet model\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
}
// Process each input file
for (const auto & fname : params.fname_inp) {
if (!params.no_prints) {
fprintf(stderr, "\nProcessing file: %s\n", fname.c_str());
}
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str());
continue;
}
if (pcmf32.empty()) {
fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str());
continue;
}
bool is_first = true;
struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
full_params.n_threads = params.n_threads;
full_params.new_token_callback = token_callback;
full_params.new_token_callback_user_data = &is_first;
const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH);
int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size());
if (ret != 0) {
fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str());
continue;
}
printf("\n");
if (params.output_txt) {
const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt";
std::ofstream fout(fname_out);
if (fout.is_open()) {
const int n_segments = parakeet_full_n_segments(pctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = parakeet_full_get_segment_text(pctx, i);
fout << text << "\n";
}
fout.close();
if (!params.no_prints) {
fprintf(stderr, "Output written to: %s\n", fname_out.c_str());
}
} else {
fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str());
}
}
if (!params.no_prints) {
parakeet_print_timings(pctx);
}
if (params.print_segments) {
const int n_segments = parakeet_full_n_segments(pctx);
fprintf(stderr, "\nSegments (%d):\n", n_segments);
for (int i = 0; i < n_segments; i++) {
const char * text = parakeet_full_get_segment_text(pctx, i);
const int64_t t0 = parakeet_full_get_segment_t0(pctx, i);
const int64_t t1 = parakeet_full_get_segment_t1(pctx, i);
const int n_tokens = parakeet_full_n_tokens(pctx, i);
fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
fprintf(stderr, "Tokens [%d]:\n", n_tokens);
for (int j = 0; j < n_tokens; j++) {
parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j);
const char * token_str = parakeet_token_to_str(pctx, token_data.id);
fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n",
j,
token_data.id,
token_data.frame_index,
token_data.duration_idx,
token_data.duration_value,
token_data.p,
token_data.plog,
(long long)token_data.t0,
(long long)token_data.t1,
token_data.is_word_start ? "true": "false",
token_str);
}
}
}
}
parakeet_free(pctx);
return 0;
}


@ -0,0 +1,7 @@
set(TARGET parakeet-quantize)
add_executable(${TARGET} parakeet-quantize.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common parakeet ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)


@ -0,0 +1,230 @@
#include "ggml.h"
#include "ggml-backend.h"
#include "common-ggml.h"
#include <cassert>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>
struct parakeet_hparams {
int32_t n_vocab = 0;
int32_t n_audio_ctx = 0;
int32_t n_audio_state = 0;
int32_t n_audio_head = 0;
int32_t n_audio_layer = 0;
int32_t n_mels = 0;
int32_t ftype = 0;
int32_t n_fft = 0;
int32_t subsampling_factor = 0;
int32_t n_subsampling_channels = 0;
int32_t n_conv_kernel = 0;
int32_t n_pred_dim = 0;
int32_t n_pred_layers = 0;
int32_t n_tdt_durations = 0;
int32_t n_max_tokens = 0;
};
static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
auto finp = std::ifstream(fname_inp, std::ios::binary);
if (!finp) {
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
return false;
}
auto fout = std::ofstream(fname_out, std::ios::binary);
if (!fout) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
return false;
}
// magic
{
uint32_t magic;
finp.read((char *) &magic, sizeof(magic));
if (magic != GGML_FILE_MAGIC) {
fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__);
return false;
}
fout.write((char *) &magic, sizeof(magic));
}
// hparams
parakeet_hparams hparams;
{
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft));
finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor));
finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels));
finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel));
finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim));
finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers));
finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations));
finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens));
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((char *) &ftype_dst, sizeof(ftype_dst));
fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft));
fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor));
fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels));
fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel));
fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim));
fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers));
fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations));
fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens));
}
// mel filterbank
{
int32_t n_mel, n_fb;
finp.read((char *) &n_mel, sizeof(n_mel));
fout.write((char *) &n_mel, sizeof(n_mel));
finp.read((char *) &n_fb, sizeof(n_fb));
fout.write((char *) &n_fb, sizeof(n_fb));
const size_t n = (size_t) n_mel * n_fb;
std::vector<float> buf(n);
finp.read((char *) buf.data(), n * sizeof(float));
fout.write((char *) buf.data(), n * sizeof(float));
}
// window function
{
int32_t n_window;
finp.read((char *) &n_window, sizeof(n_window));
fout.write((char *) &n_window, sizeof(n_window));
std::vector<float> buf(n_window);
finp.read((char *) buf.data(), n_window * sizeof(float));
fout.write((char *) buf.data(), n_window * sizeof(float));
}
// TDT durations
{
std::vector<uint32_t> buf(hparams.n_tdt_durations);
finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t));
fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t));
}
// vocab
{
int32_t n_tokens;
finp.read((char *) &n_tokens, sizeof(n_tokens));
fout.write((char *) &n_tokens, sizeof(n_tokens));
for (int i = 0; i < n_tokens; ++i) {
int32_t len;
finp.read((char *) &len, sizeof(len));
fout.write((char *) &len, sizeof(len));
std::string token(len, '\0');
finp.read(&token[0], len);
fout.write(&token[0], len);
}
}
// tensors — quantize 2D weights skipping tensors that must stay F32:
// ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights.
// pos_bias_u / pos_bias_v are declared F32 in the loader.
const std::vector<std::string> to_quant = { ".*" };
std::vector<std::string> to_skip = {
// CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights)
"encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight",
// Declared F32 in loader (pos_bias tensors)
"encoder\\.layers\\..+\\.self_attn\\.pos_bias_u",
"encoder\\.layers\\..+\\.self_attn\\.pos_bias_v",
};
// Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant
// types (block size 256) cannot quantize 640 evenly, so keep them F32. For
// other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be
// quantized normally. The loader mirrors this logic at load time.
{
const ggml_type qtype = ggml_ftype_to_ggml_type(ftype);
const int32_t blck = ggml_blck_size(qtype);
if (blck > 1 && hparams.n_pred_dim % blck != 0) {
to_skip.push_back("decoder\\.prediction\\.embed\\.weight");
to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*");
to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*");
to_skip.push_back("joint\\.pred\\.weight");
to_skip.push_back("joint\\.joint_net\\.2\\.weight");
}
}
if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) {
fprintf(stderr, "%s: failed to quantize tensors\n", __func__);
return false;
}
finp.close();
fout.close();
return true;
}
int main(int argc, char ** argv) {
ggml_backend_load_all();
if (argc != 4) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
ggml_print_ftypes(stderr);
return 1;
}
// initialise F16 lookup tables
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
const std::string fname_inp = argv[1];
const std::string fname_out = argv[2];
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
if (ftype == GGML_FTYPE_UNKNOWN) {
fprintf(stderr, "%s: invalid quantization type\n", argv[0]);
ggml_print_ftypes(stderr);
return 1;
}
const int64_t t_start_us = ggml_time_us();
if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str());
return 1;
}
printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f);
printf("%s: output model = %s\n", argv[0], fname_out.c_str());
return 0;
}
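The block-size check in `parakeet_model_quantize` can be illustrated outside of ggml. The sketch below is a minimal Python restatement of the `blck > 1 && n_pred_dim % blck != 0` test, using 640 as the `n_pred_dim` quoted in the source comment. The dictionary keys are illustrative labels rather than ggml identifiers, and the function name `pred_tensors_skipped` is hypothetical; the block sizes (32 for Q4_0/Q8_0, 256 for K-quants) are those stated in the comment above.

```python
# Block sizes per quantization family, as stated in the source comment.
# Keys are illustrative labels, not ggml enum names.
BLOCK_SIZE = {"f16": 1, "q4_0": 32, "q8_0": 32, "q4_k": 256, "q5_k": 256}


def pred_tensors_skipped(qtype, n_pred_dim):
    """True if prediction/joint tensors would be left unquantized."""
    blck = BLOCK_SIZE[qtype]
    # A row of n_pred_dim values must split into whole quantization blocks.
    return blck > 1 and n_pred_dim % blck != 0


for qtype in ("q8_0", "q4_0", "q4_k"):
    print(qtype, pred_tensors_skipped(qtype, 640))
```

With `n_pred_dim = 640`, the 32-wide types divide evenly (640 = 20 × 32), while 256 leaves a remainder of 128, so only the K-quant types trigger the skip.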

include/parakeet.h Normal file

@ -0,0 +1,342 @@
#ifndef PARAKEET_H
#define PARAKEET_H
#include "ggml.h"
#include "ggml-cpu.h"
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#ifdef __GNUC__
# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
# define PARAKEET_DEPRECATED(func, hint) func
#endif
#ifdef PARAKEET_SHARED
# ifdef _WIN32
# ifdef PARAKEET_BUILD
# define PARAKEET_API __declspec(dllexport)
# else
# define PARAKEET_API __declspec(dllimport)
# endif
# else
# define PARAKEET_API __attribute__ ((visibility ("default")))
# endif
#else
# define PARAKEET_API
#endif
#define PARAKEET_SAMPLE_RATE 16000
#define PARAKEET_HOP_LENGTH 160
#ifdef __cplusplus
extern "C" {
#endif
struct parakeet_context;
struct parakeet_state;
struct parakeet_full_params;
typedef int32_t parakeet_pos;
typedef int32_t parakeet_token;
typedef int32_t parakeet_seq_id;
struct parakeet_context_params {
bool use_gpu;
int gpu_device; // CUDA device
};
typedef struct parakeet_token_data {
parakeet_token id; // the BPE subword ID (0-8191)
int duration_idx; // index into the model's durations array
int duration_value; // actual duration value
int frame_index;
float p;
float plog;
int64_t t0;
int64_t t1;
bool is_word_start;
} parakeet_token_data;
typedef struct parakeet_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} parakeet_model_loader;
PARAKEET_API const char * parakeet_version(void);
// Various functions for loading a ggml parakeet model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params);
PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params);
PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params);
// These are the same as the above, but the internal state of the context is not allocated automatically
// It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523)
PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params);
PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params);
PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params);
PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx);
// Frees all allocated memory
PARAKEET_API void parakeet_free (struct parakeet_context * ctx);
PARAKEET_API void parakeet_free_state(struct parakeet_state * state);
PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params);
PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided parakeet context.
// Returns 0 on success
PARAKEET_API int parakeet_pcm_to_mel(
struct parakeet_context * ctx,
const float * samples,
int n_samples,
int n_threads);
PARAKEET_API int parakeet_pcm_to_mel_with_state(
struct parakeet_context * ctx,
struct parakeet_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context.
// Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 128
// Returns 0 on success
PARAKEET_API int parakeet_set_mel(
struct parakeet_context * ctx,
const float * data,
int n_len,
int n_mel);
PARAKEET_API int parakeet_set_mel_with_state(
struct parakeet_context * ctx,
struct parakeet_state * state,
const float * data,
int n_len,
int n_mel);
// Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context.
// Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
PARAKEET_API int parakeet_encode(
struct parakeet_context * ctx,
int offset,
int n_threads);
PARAKEET_API int parakeet_encode_with_state(
struct parakeet_context * ctx,
struct parakeet_state * state,
int offset,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
PARAKEET_API int parakeet_tokenize(
struct parakeet_context * ctx,
const char * text,
parakeet_token * tokens,
int n_max_tokens);
// Return the number of tokens in the provided text
// Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0)
PARAKEET_API int parakeet_token_count(struct parakeet_context * ctx, const char * text);
PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length
PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length
PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx);
PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx);
PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx);
// Token logits obtained from the last call to parakeet_full/parakeet_chunk
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx);
PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state);
// Token Id -> String. Uses the vocabulary in the provided context
PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token);
PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len);
// Special tokens
PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx);
PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx);
PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx);
// Performance information from the default state.
struct parakeet_timings {
float sample_ms;
float encode_ms;
float decode_ms;
};
PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx);
PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx);
PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx);
// Print system information
PARAKEET_API const char * parakeet_print_system_info(void);
// Available sampling strategies
enum parakeet_sampling_strategy {
PARAKEET_SAMPLING_GREEDY,
};
// Token callback.
// Called for each new predicted token.
// Use the parakeet_full_...() functions to obtain the text segments
typedef void (*parakeet_new_token_callback)(
struct parakeet_context * ctx,
struct parakeet_state * state,
const parakeet_token_data * token_data,
void * user_data);
// Text segment callback
// Called on every newly generated text segment
// Use the parakeet_full_...() functions to obtain the text segments
typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data);
// Progress callback
typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data);
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data);
// Parameters for the parakeet_full() function
// If you change the order or add new parameters, make sure to update the default values in parakeet.cpp:
// parakeet_full_default_params()
struct parakeet_full_params {
enum parakeet_sampling_strategy strategy;
int n_threads;
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool no_context; // do not use past transcription (if any) as context
int audio_ctx; // overwrite the audio context size (0 = use default)
// called for every newly generated text segment
parakeet_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called for every newly generated token
parakeet_new_token_callback new_token_callback;
void * new_token_callback_user_data;
// called on each progress update
parakeet_progress_callback progress_callback;
void * progress_callback_user_data;
// called each time before the encoder starts
parakeet_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called each time before ggml computation starts
ggml_abort_callback abort_callback;
void * abort_callback_user_data;
};
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params()
PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void);
PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void);
PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy);
PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Not thread safe for same context
PARAKEET_API int parakeet_full(
struct parakeet_context * ctx,
struct parakeet_full_params params,
const float * samples,
int n_samples);
PARAKEET_API int parakeet_full_with_state(
struct parakeet_context * ctx,
struct parakeet_state * state,
struct parakeet_full_params params,
const float * samples,
int n_samples);
// Process a single chunk of audio data that fits within the model's audio context window.
// This is more efficient than parakeet_full() for short audio clips.
PARAKEET_API int parakeet_chunk(
struct parakeet_context * ctx,
struct parakeet_state * state,
struct parakeet_full_params params,
const float * samples,
int n_samples);
// Number of generated text segments
PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx);
PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state);
// Get the start and end time of the specified segment
PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment);
PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment);
PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment);
PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment);
// Get the text of the specified segment
PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment);
PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment);
// Get number of tokens in the specified segment
PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment);
PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment);
// Get the token text of the specified token in the specified segment
PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token);
PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token);
// Get the token id of the specified token in the specified segment
PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token);
PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment
PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token);
PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment
PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token);
PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token);
// Control logging output; default behavior is to print to stderr
PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data);
#ifdef __cplusplus
}
#endif
#endif


@ -0,0 +1,337 @@
#!/usr/bin/env python3
# Convert Parakeet TDT model from NeMo format to ggml format
#
# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32]
#
# The NeMo file is a tar archive containing:
# - model_weights.ckpt (PyTorch checkpoint)
# - model_config.yaml (model configuration)
# - tokenizer files
#
# This script extracts the NeMo archive, loads the model weights and configuration,
# and saves them in ggml format compatible with whisper.cpp.
#
import torch
import argparse
import io
import os
import sys
import struct
import tarfile
import tempfile
import shutil
import yaml
import numpy as np
from pathlib import Path
from typing import Optional
def hz_to_mel(freq):
return 2595.0 * np.log10(1.0 + freq / 700.0)
def mel_to_hz(mel):
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
def extract_nemo_archive(nemo_path, extract_dir):
print(f"Extracting {nemo_path} to {extract_dir}")
with tarfile.open(nemo_path, 'r') as tar:
tar.extractall(path=extract_dir)
print("Extraction complete")
def load_model_config(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
def load_tokenizer(extract_dir, config):
tokenizer_model_path = None
tokenizer_vocab_path = None
for file in os.listdir(extract_dir):
if file.endswith('_tokenizer.model'):
tokenizer_model_path = os.path.join(extract_dir, file)
elif file.endswith('tokenizer.vocab'):
tokenizer_vocab_path = os.path.join(extract_dir, file)
if not tokenizer_model_path:
raise FileNotFoundError("Tokenizer model file not found")
if not tokenizer_vocab_path:
raise FileNotFoundError("Tokenizer vocab file not found")
tokens = {}
with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
parts = line.strip().split('\t')
if len(parts) >= 1:
token = parts[0]
tokens[token.encode('utf-8')] = idx
print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}")
if len(tokens) != 8192:
print(f"WARNING: Expected 8192 tokens, got {len(tokens)}")
return tokens
def write_tensor(fout, name, data, use_f16=True, force_f32=False):
if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1:
data = data.reshape(1, -1, 1, 1)
print(f" Reshaped conv bias {name} to {data.shape}")
n_dims = len(data.shape)
ftype = 1 if use_f16 and not force_f32 else 0
if force_f32:
data = data.astype(np.float32)
elif use_f16:
if n_dims < 2 or 'bias' in name or 'norm' in name or \
('pre_encode.conv' in name and n_dims == 4) or \
'depthwise_conv.weight' in name:
data = data.astype(np.float32)
ftype = 0
else:
data = data.astype(np.float16)
else:
data = data.astype(np.float32)
dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)]
print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}")
name_bytes = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(name_bytes)
data.tofile(fout)
def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None):
nemo_path = Path(nemo_path)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Create temporary directory for extraction
with tempfile.TemporaryDirectory() as temp_dir:
extract_nemo_archive(nemo_path, temp_dir)
config_path = os.path.join(temp_dir, 'model_config.yaml')
config = load_model_config(config_path)
print("Model configuration:")
print(f" Sample rate: {config['sample_rate']}")
print(f" Encoder layers: {config['encoder']['n_layers']}")
print(f" Encoder d_model: {config['encoder']['d_model']}")
print(f" Mel features: {config['preprocessor']['features']}")
weights_path = os.path.join(temp_dir, 'model_weights.ckpt')
print(f"\nLoading model weights from {weights_path}")
checkpoint = torch.load(weights_path, map_location='cpu')
# Extract state dict
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
print(f"Loaded {len(state_dict)} tensors")
# Load tokenizer
print("\nLoading tokenizer...")
tokens = load_tokenizer(temp_dir, config)
print(f"Loaded {len(tokens)} tokens")
# Prepare hyperparameters for the Parakeet ggml format.
hparams = {
'n_audio_ctx': 5000,
'n_audio_state': config['encoder']['d_model'],
'n_audio_head': config['encoder']['n_heads'],
'n_audio_layer': config['encoder']['n_layers'],
'n_mels': config['preprocessor']['features'],
'n_fft': config['preprocessor']['n_fft'],
'subsampling_factor': config['encoder']['subsampling_factor'],
'n_subsampling_channels': config['encoder']['subsampling_conv_channels'],
'n_conv_kernel': config['encoder']['conv_kernel_size'],
'n_pred_dim': config['decoder']['prednet']['pred_hidden'],
'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'],
'n_vocab': config['decoder']['vocab_size'],
'n_tdt_durations': config['model_defaults']['num_tdt_durations'],
'n_max_tokens': config['decoding']['greedy']['max_symbols'],
}
print("\nGGML hyperparameters:")
for key, value in hparams.items():
print(f" {key}: {value}")
# Create output file
if out_name:
fname_out = output_dir / out_name
else:
fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin")
print(f"\nWriting to {fname_out}")
with open(fname_out, 'wb') as fout:
# Write magic number
fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex
# Write hyperparameters
fout.write(struct.pack("i", hparams['n_vocab']))
fout.write(struct.pack("i", hparams['n_audio_ctx']))
fout.write(struct.pack("i", hparams['n_audio_state']))
fout.write(struct.pack("i", hparams['n_audio_head']))
fout.write(struct.pack("i", hparams['n_audio_layer']))
fout.write(struct.pack("i", hparams['n_mels']))
fout.write(struct.pack("i", 1 if use_f16 else 0))
fout.write(struct.pack("i", hparams['n_fft']))
fout.write(struct.pack("i", hparams['subsampling_factor']))
fout.write(struct.pack("i", hparams['n_subsampling_channels']))
fout.write(struct.pack("i", hparams['n_conv_kernel']))
fout.write(struct.pack("i", hparams['n_pred_dim']))
fout.write(struct.pack("i", hparams['n_pred_layers']))
fout.write(struct.pack("i", hparams['n_tdt_durations']))
fout.write(struct.pack("i", hparams['n_max_tokens']))
# Extract mel filterbank from model
fb_key = None
for key in state_dict.keys():
if 'featurizer.fb' in key or 'filterbank' in key.lower():
fb_key = key
break
if not fb_key:
print("\nERROR: Mel filterbank not found in model!")
print("Expected tensor with 'featurizer.fb' or 'filterbank' in name")
print("\nAvailable preprocessor tensors:")
for key in sorted(state_dict.keys()):
if 'preprocessor' in key or 'featurizer' in key:
print(f" {key}: {state_dict[key].shape}")
raise ValueError("Mel filterbank tensor not found in model")
print(f"\nUsing model's mel filterbank from: {fb_key}")
mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32)
print(f" Filterbank shape: {mel_filters.shape}")
print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}")
print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}")
print(f" First row sum: {mel_filters[0].sum():.6f}")
if len(mel_filters.shape) != 2:
raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}")
n_mels, n_freqs = mel_filters.shape
fout.write(struct.pack("i", n_mels)) # n_mel
fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins)
# Write mel filterbank (row-major float32, n_mels x n_freqs)
mel_filters.tofile(fout)
# Extract window function from model
window_key = None
for key in state_dict.keys():
if 'featurizer.window' in key or ('preproc' in key and 'window' in key):
window_key = key
break
if not window_key:
print("\nERROR: Window function not found in model!")
print("Expected tensor with 'featurizer.window' in name")
raise ValueError("Window function tensor not found in model")
print(f"\nUsing model's window function from: {window_key}")
window = state_dict[window_key].squeeze().numpy().astype(np.float32)
print(f" Window shape: {window.shape}")
print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}")
print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}")
print(f" Window sum: {window.sum():.6f}")
if len(window.shape) != 1:
raise ValueError(f"Expected 1D window, got shape {window.shape}")
n_window = window.shape[0]
fout.write(struct.pack("i", n_window))
# Write window function (float32, n_window values)
window.tofile(fout)
# Write TDT durations
tdt_durations = config['model_defaults']['tdt_durations']
if len(tdt_durations) != hparams['n_tdt_durations']:
raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}")
for duration in tdt_durations:
fout.write(struct.pack("I", duration))
fout.write(struct.pack("i", len(tokens)))
for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]):
fout.write(struct.pack("i", len(token_bytes)))
fout.write(token_bytes)
# Pre-collect prediction LSTM input-hidden biases so they can be
# folded into the hidden-hidden bias during the main write loop.
lstm_prefix = 'decoder.prediction.dec_rnn.lstm'
pred_bias_ih = {}
for key, t in state_dict.items():
if f'{lstm_prefix}.bias_ih_l' in key:
layer_idx = int(key.rsplit('bias_ih_l', 1)[1])
pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32)
print("\nConverting model weights...")
for name, tensor in state_dict.items():
# Skip the filterbank and window - already written in preprocessing section
if name == fb_key:
continue
if name == window_key:
continue
# bias_ih is folded into bias_hh below; skip writing it separately
if f'{lstm_prefix}.bias_ih_l' in name:
continue
# Don't squeeze Conv2d weights - they need to preserve all 4 dimensions
if 'conv' in name and 'weight' in name and len(tensor.shape) == 4:
data = tensor.numpy()
else:
data = tensor.squeeze().numpy()
# For prediction LSTM weights/biases:
# Fold bias_ih into bias_hh (bias_ih already skipped above).
# Reorder gates (input, forget, cell, output) from PyTorch layout
# [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs
# (i, f, o) are contiguous.
if name.startswith(f'{lstm_prefix}.'):
if f'{lstm_prefix}.bias_hh_l' in name:
layer_idx = int(name.rsplit('bias_hh_l', 1)[1])
data = data.astype(np.float32) + pred_bias_ih[layer_idx]
name = name.replace('bias_hh_l', 'bias_h_l')
h = data.shape[0] // 4
data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)
write_tensor(fout, name, data, use_f16=use_f16)
print(f"\nConversion complete!")
print(f"Output file: {fname_out}")
print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB")
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Convert Parakeet TDT model from NeMo format to ggml format'
)
parser.add_argument('--model', type=str, required=True,
help='Path to Parakeet .nemo model file')
parser.add_argument('--out-dir', type=str, required=True,
help='Directory to write ggml model file')
parser.add_argument('--use-f32', action='store_true', default=False,
help='Use f32 instead of f16 (default: f16)')
parser.add_argument('--out-name', type=str, default=None,
help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)')
args = parser.parse_args()
if not os.path.exists(args.model):
print(f"Error: {args.model} not found")
sys.exit(1)
use_f16 = not args.use_f32
convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name)
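The bias folding and gate reordering applied to the prediction LSTM can be looked at in isolation. The sketch below is not part of the conversion script: `fold_and_reorder` and the toy values are hypothetical, and it assumes the PyTorch `[i, f, g, o]` gate layout named in the script's comments. Folding is safe because both biases are added to the same pre-activation, so only their sum matters.

```python
import numpy as np

def reorder_gates(data):
    # Move the cell gate (g) block last: [i, f, g, o] -> [i, f, o, g].
    h = data.shape[0] // 4
    return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)

def fold_and_reorder(bias_ih, bias_hh):
    # Hypothetical helper: sum the two biases, then reorder the gate blocks.
    folded = bias_ih.astype(np.float32) + bias_hh.astype(np.float32)
    return reorder_gates(folded)

# Toy example with hidden size 2; gate blocks are labelled 1=i, 2=f, 3=g, 4=o.
hidden = 2
bias_hh = np.repeat(np.array([1, 2, 3, 4], dtype=np.float32), hidden)
bias_ih = np.full_like(bias_hh, 0.5)
reordered = fold_and_reorder(bias_ih, bias_hh)
print(reordered.tolist())  # [1.5, 1.5, 2.5, 2.5, 4.5, 4.5, 3.5, 3.5]
```

The same `reorder_gates` works on 2D weight matrices, since the gate blocks are stacked along the first axis.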

Binary file not shown.


@ -0,0 +1,182 @@
#!/usr/bin/env python3
import struct
import sys
import numpy as np
from pathlib import Path
def write_tensor(fout, name, data):
n_dims = len(data.shape)
data = data.astype(np.float32)
ftype = 0 # GGML_TYPE_F32
name_bytes = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(name_bytes)
data.tofile(fout)
def generate(output_path):
rng = np.random.default_rng(42)
hparams = {
'n_vocab': 10,
'n_audio_ctx': 3200,
'n_audio_state': 8,
'n_audio_head': 2,
'n_audio_layer': 1,
'n_mels': 16,
'ftype': 0,
'n_fft': 64,
'subsampling_factor': 8,
'n_subsampling_channels': 4,
'n_conv_kernel': 3,
'n_pred_dim': 8,
'n_pred_layers': 1,
'n_tdt_durations': 2,
'n_max_tokens': 5,
}
n_vocab = hparams['n_vocab']
n_state = hparams['n_audio_state']
n_head = hparams['n_audio_head']
n_layer = hparams['n_audio_layer']
n_mels = hparams['n_mels']
n_fft = hparams['n_fft']
n_sub_fac = hparams['subsampling_factor']
n_sub_ch = hparams['n_subsampling_channels']
n_conv_ker = hparams['n_conv_kernel']
dec_dim = hparams['n_pred_dim']
n_pred_l = hparams['n_pred_layers']
n_tdt = hparams['n_tdt_durations']
n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch
n_head_dim = n_state // n_head
n_pred_embed = n_vocab + 1
n_lstm_gates = 4 * dec_dim
n_joint_out = n_vocab + n_tdt + 1
n_freqs = n_fft // 2 + 1
def f32(*shape):
return rng.standard_normal(shape).astype(np.float32)
with open(output_path, 'wb') as fout:
fout.write(struct.pack("I", 0x67676d6c))
for key in ['n_vocab',
'n_audio_ctx',
'n_audio_state',
'n_audio_head',
'n_audio_layer',
'n_mels',
'ftype',
'n_fft',
'subsampling_factor',
'n_subsampling_channels',
'n_conv_kernel',
'n_pred_dim',
'n_pred_layers',
'n_tdt_durations',
'n_max_tokens']:
fout.write(struct.pack("i", hparams[key]))
fout.write(struct.pack("i", n_mels))
fout.write(struct.pack("i", n_freqs))
f32(n_mels, n_freqs).tofile(fout)
fout.write(struct.pack("i", n_fft))
f32(n_fft).tofile(fout)
for d in range(n_tdt):
fout.write(struct.pack("I", d))
tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)]
assert len(tokens) == n_vocab
fout.write(struct.pack("i", n_vocab))
for tok in tokens:
tok_bytes = tok.encode('utf-8')
fout.write(struct.pack("i", len(tok_bytes)))
fout.write(tok_bytes)
write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc))
write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state))
write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3))
write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1))
write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3))
write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1))
write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1))
write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1))
write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3))
write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1))
write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1))
write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1))
for i in range(n_layer):
p = f"encoder.layers.{i}"
write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state))
write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state))
write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state))
write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state))
write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state))
write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state))
write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state))
write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker))
write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state))
write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state))
write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state))
write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state)))
num_batches = np.zeros(1, dtype=np.int32)
write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches)
write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state))
write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state))
write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state))
write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim))
write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim))
write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state))
write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state))
write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state))
write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state))
write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state))
write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state))
write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state))
write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state))
write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state))
write_tensor(fout, f"{p}.norm_out.weight", f32(n_state))
write_tensor(fout, f"{p}.norm_out.bias", f32(n_state))
write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim))
def reorder_gates(data):
h = data.shape[0] // 4
return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)
for i in range(n_pred_l):
base = f"decoder.prediction.dec_rnn.lstm"
write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim)))
write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim)))
write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates)))
write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim))
write_tensor(fout, "joint.pred.bias", f32(dec_dim))
write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state))
write_tensor(fout, "joint.enc.bias", f32(dec_dim))
write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim))
write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out))
size = Path(output_path).stat().st_size
print(f"Generated {output_path} ({size / 1024:.1f} KB)")
if __name__ == '__main__':
output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin'
generate(output)
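To sanity-check the layout this generator writes, the header and a tensor record can be round-tripped in memory. This is a minimal sketch under stated assumptions, not the loader in `parakeet.cpp`: `write_tensor_record` and `read_tensor_record` are hypothetical names, only F32 data is handled, and the filterbank, window, TDT duration and vocabulary sections that sit between the header and the tensors in a real file are skipped.

```python
import io
import struct
import numpy as np

# Header field order, as written by the scripts above.
HPARAM_KEYS = [
    'n_vocab', 'n_audio_ctx', 'n_audio_state', 'n_audio_head', 'n_audio_layer',
    'n_mels', 'ftype', 'n_fft', 'subsampling_factor', 'n_subsampling_channels',
    'n_conv_kernel', 'n_pred_dim', 'n_pred_layers', 'n_tdt_durations', 'n_max_tokens',
]

def write_tensor_record(fout, name, data):
    # Same record layout as write_tensor: header, reversed dims, name, raw data.
    data = np.ascontiguousarray(data, dtype=np.float32)
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", data.ndim, len(name_bytes), 0))
    for dim in reversed(data.shape):
        fout.write(struct.pack("i", dim))
    fout.write(name_bytes)
    fout.write(data.tobytes())  # tobytes() instead of tofile() so BytesIO works

def read_tensor_record(fin):
    # Hypothetical reader: undo the dimension reversal to recover the shape.
    n_dims, name_len, ftype = struct.unpack("iii", fin.read(12))
    dims = struct.unpack(f"{n_dims}i", fin.read(4 * n_dims))
    name = fin.read(name_len).decode('utf-8')
    shape = tuple(reversed(dims))
    count = int(np.prod(shape))
    data = np.frombuffer(fin.read(4 * count), dtype=np.float32).reshape(shape)
    return name, ftype, data

buf = io.BytesIO()
buf.write(struct.pack("I", 0x67676d6c))  # magic
hparams = {key: i + 1 for i, key in enumerate(HPARAM_KEYS)}  # dummy values
for key in HPARAM_KEYS:
    buf.write(struct.pack("i", hparams[key]))
original = np.arange(6, dtype=np.float32).reshape(2, 3)
write_tensor_record(buf, "joint.enc.weight", original)

buf.seek(0)
magic, = struct.unpack("I", buf.read(4))
values = struct.unpack(f"{len(HPARAM_KEYS)}i", buf.read(4 * len(HPARAM_KEYS)))
parsed = dict(zip(HPARAM_KEYS, values))
name, ftype, restored = read_tensor_record(buf)
print(hex(magic), name, restored.shape)
```

Note that the `struct` formats used here and in the scripts rely on native byte order, so the written file matches the byte order of the machine that ran the conversion.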


@ -0,0 +1,3 @@
torch
numpy
pyyaml

15
scripts/quantize-parakeet.sh Executable file

@ -0,0 +1,15 @@
#!/bin/bash
set -e
build_dir=build
modelname=ggml-parakeet-tdt-0.6b-v3
model=models/${modelname}-f32.bin
cmd=parakeet-quantize
cmake --build ${build_dir} --target $cmd -j 12
${build_dir}/bin/${cmd} $model models/${modelname}-q8_0.bin q8_0
${build_dir}/bin/${cmd} $model models/${modelname}-q4_0.bin q4_0
${build_dir}/bin/${cmd} $model models/${modelname}-q4_k.bin q4_k
${build_dir}/bin/${cmd} $model models/${modelname}-q2_k.bin q2_k

157
scripts/upload-parakeet.py Normal file

@ -0,0 +1,157 @@
import argparse
import os
from huggingface_hub import HfApi, create_repo
USER_NAME = "ggml-org"
REPO_ID = f"{USER_NAME}/parakeet-GGUF"
MODELS = {
"f32": {
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-f32.bin",
"remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin",
"description": "Full precision (F32)",
},
"f16": {
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-f16.bin",
"remote_name": "ggml-parakeet-tdt-0.6b-v3-f16.bin",
"description": "Half precision (F16)",
},
"q8_0": {
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-q8_0.bin",
"remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin",
"description": "8-bit quantized (Q8_0)",
},
"q4_0": {
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_0.bin",
"remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin",
"description": "4-bit quantized (Q4_0)",
},
"q4_k": {
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_k.bin",
"remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_k.bin",
"description": "4-bit K-quantized (Q4_K)",
},
}
def build_model_card(uploaded_variants):
lines = [
f"---",
f"license: mit",
f"base_model: nvidia/parakeet-tdt-0.6b-v3",
f"tags:",
f"- gguf",
f"- asr",
f"---",
f"",
f"# Parakeet TDT 0.6B v3 (GGUF)",
f"",
f"GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).",
f"",
f"## Available files",
f"",
]
for key, m in MODELS.items():
if key in uploaded_variants:
lines.append(f"- `{m['remote_name']}` — {m['description']}")
lines += [
f"",
f"## Usage",
f"",
f"Build parakeet-cli:",
f"```console",
f"git clone https://github.com/ggml-org/whisper.cpp.git",
f"cd whisper.cpp",
f"cmake -B build -S .",
f"cmake --build build --target parakeet-cli -j $(nproc)",
f"```",
f"",
f"Download a model (e.g. Q8_0):",
f"```console",
f"hf download {REPO_ID} {MODELS['q8_0']['remote_name']} --local-dir models",
f"```",
f"",
f"Run:",
f"```console",
f"./build/bin/parakeet-cli -m models/{MODELS['q8_0']['remote_name']} -f samples/jfk.wav",
f"```",
f"",
]
return "\n".join(lines)
def upload_variant(api, key):
m = MODELS[key]
local_path = m["local_path"]
if not os.path.exists(local_path):
print(f" Skipping {key}: {local_path} not found")
return False
print(f" Uploading {m['remote_name']} ({m['description']})...")
api.upload_file(
path_or_fileobj=local_path,
path_in_repo=m["remote_name"],
repo_id=REPO_ID,
repo_type="model",
commit_message=f"Upload {m['remote_name']}",
)
return True
def main():
parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face")
parser.add_argument(
"variants",
nargs="*",
default=None,
metavar="{" + ",".join(MODELS.keys()) + "}",
help="Model variants to upload (default: all)",
)
parser.add_argument(
"--no-model-card",
action="store_true",
help="Skip updating the model card README",
)
args = parser.parse_args()
api = HfApi()
create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
variants = args.variants if args.variants else list(MODELS.keys())
unknown = [v for v in variants if v not in MODELS]
if unknown:
parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})")
uploaded = []
for key in variants:
if upload_variant(api, key):
uploaded.append(key)
if not uploaded:
print("No models were uploaded.")
return
if not args.no_model_card:
print("Updating model card...")
# list_repo_files() returns plain path strings, so compare names directly
# and fetch the listing only once.
repo_files = set(api.list_repo_files(REPO_ID, repo_type="model"))
existing = [k for k in MODELS
if k in uploaded or MODELS[k]["remote_name"] in repo_files]
card = build_model_card(existing if existing else uploaded)
api.upload_file(
path_or_fileobj=card.encode(),
path_in_repo="README.md",
repo_id=REPO_ID,
repo_type="model",
commit_message="Update README.md",
)
print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}")
if __name__ == "__main__":
main()
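The model card is meant to list every variant that was either just uploaded or is already present in the repository. A minimal sketch of that selection, with no network access, is shown below. `variants_present` is a hypothetical helper, and it assumes the repository listing is a list of path strings, which is what `HfApi.list_repo_files` returns in the `huggingface_hub` versions I am aware of.

```python
def variants_present(models, uploaded, repo_files):
    # Keep variants that were uploaded in this run or already exist remotely.
    remote = set(repo_files)
    return [k for k, m in models.items()
            if k in uploaded or m["remote_name"] in remote]

# Trimmed copy of the MODELS table, for illustration only.
models = {
    "f32": {"remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin"},
    "q8_0": {"remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin"},
    "q4_0": {"remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin"},
}
repo_files = ["README.md", "ggml-parakeet-tdt-0.6b-v3-f32.bin"]
present = variants_present(models, uploaded=["q8_0"], repo_files=repo_files)
print(present)  # ['f32', 'q8_0']
```

Keeping the selection in a pure function makes it easy to test without touching the Hugging Face API.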


@ -109,23 +109,43 @@ add_library(whisper
whisper.cpp
)
add_library(parakeet
../include/parakeet.h
parakeet-arch.h
parakeet.cpp
)
target_include_directories(parakeet PUBLIC . ../include)
target_compile_features (parakeet PUBLIC cxx_std_11)
target_link_libraries(parakeet PUBLIC ggml Threads::Threads)
# Set the version numbers
set_target_properties(whisper PROPERTIES
VERSION ${PROJECT_VERSION}
SOVERSION ${SOVERSION}
)
set_target_properties(parakeet PROPERTIES
VERSION ${PROJECT_VERSION}
SOVERSION ${SOVERSION}
)
target_include_directories(whisper PUBLIC . ../include)
target_compile_features (whisper PUBLIC cxx_std_11) # don't bump
if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN)
set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN)
endif()
if (WHISPER_EXTRA_FLAGS)
target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS})
endif()
if (PARAKEET_EXTRA_FLAGS)
target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS})
endif()
find_package(Threads REQUIRED)
target_link_libraries(whisper PUBLIC ggml Threads::Threads)
@ -144,4 +164,7 @@ endif()
if (BUILD_SHARED_LIBS)
set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD)
set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD)
endif()

188
src/parakeet-arch.h Normal file

@ -0,0 +1,188 @@
#pragma once
#include "ggml.h"
#include <map>
enum parakeet_tensor {
// Encoder pre_encode
PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT,
PARAKEET_TENSOR_ENC_PRE_OUT_BIAS,
PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT,
PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS,
PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT,
PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS,
PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT,
PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS,
PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT,
PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS,
PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT,
PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS,
// Encoder layers (per-layer)
PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_FF1_BIAS,
PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT,
PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_CONV_BIAS,
PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT,
PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT,
PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT,
PARAKEET_TENSOR_ENC_CONV_BN_BIAS,
PARAKEET_TENSOR_ENC_CONV_BN_MEAN,
PARAKEET_TENSOR_ENC_CONV_BN_VAR,
PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES,
PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS,
PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U,
PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V,
PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT,
PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT,
PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT,
PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT,
PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_FF2_BIAS,
PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT,
PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT,
PARAKEET_TENSOR_ENC_NORM_OUT_BIAS,
// Prediction network
PARAKEET_TENSOR_PRED_EMBED_WEIGHT,
PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH,
PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH,
PARAKEET_TENSOR_PRED_LSTM_BIAS_H,
// Joint network
PARAKEET_TENSOR_JOINT_PRED_WEIGHT,
PARAKEET_TENSOR_JOINT_PRED_BIAS,
PARAKEET_TENSOR_JOINT_ENC_WEIGHT,
PARAKEET_TENSOR_JOINT_ENC_BIAS,
PARAKEET_TENSOR_JOINT_NET_WEIGHT,
PARAKEET_TENSOR_JOINT_NET_BIAS,
};
static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = {
// Encoder pre_encode
{PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"},
{PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"},
{PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"},
{PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"},
{PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"},
{PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"},
{PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"},
{PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"},
{PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"},
{PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"},
{PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"},
{PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"},
// Encoder layers (use %d for layer number)
{PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"},
{PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"},
{PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"},
{PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"},
{PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"},
{PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"},
{PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"},
{PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"},
{PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"},
{PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"},
{PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"},
{PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"},
{PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"},
{PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"},
{PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"},
{PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"},
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"},
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"},
{PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"},
{PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"},
{PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"},
{PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"},
{PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"},
{PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"},
{PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"},
{PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"},
{PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"},
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"},
{PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"},
// Prediction network
{PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"},
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"},
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"},
{PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"},
// Joint network
{PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"},
{PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"},
{PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"},
{PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"},
{PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"},
{PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"},
};
static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = {
// Encoder pre_encode
{PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD},
// Encoder layers
{PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL},
{PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL},
{PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL},
{PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB},
{PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV},
{PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE},
{PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL},
{PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL},
{PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL},
{PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL},
{PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD},
// Prediction network
{PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS},
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD},
// Joint network
{PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD},
{PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT},
{PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD},
};
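The per-layer entries in `PARAKEET_TENSOR_NAMES` are printf-style templates with a `%d` placeholder for the layer index. How `parakeet.cpp` expands them is not visible in this diff, so the following is only a sketch of the idea in Python (the language of the conversion scripts). `expand` and `layer_names` are hypothetical helpers, handy for cross-checking the names a converter emits.

```python
def expand(template, index=None):
    # Substitute the layer index only when the template has a placeholder.
    return template % index if "%d" in template else template

def layer_names(template, count):
    # Expand one template for layers 0..count-1.
    return [expand(template, i) for i in range(count)]

norm_out = layer_names("encoder.layers.%d.norm_out.weight", 2)
lstm_bias = layer_names("decoder.prediction.dec_rnn.lstm.bias_h_l%d", 1)
joint = expand("joint.enc.weight")
print(norm_out, lstm_bias, joint)
```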

3838
src/parakeet.cpp Normal file

File diff suppressed because it is too large.


@ -118,3 +118,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en")
# Parakeet model loading test
set(PARAKEET_TEST test-parakeet)
add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp)
target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples)
target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common)
target_compile_definitions(${PARAKEET_TEST} PRIVATE
PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin"
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST})
set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh")
# The following parakeet tests require a real ggml-parakeet-tdt model to have
# been converted or downloaded:
# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models
#
# They also require more audio samples than are shipped by default. These can
# be downloaded by running:
# $ make samples
function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH)
set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0")
if (ARGC GREATER 4)
set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}")
endif()
add_executable(${TEST_TARGET} ${TEST_SOURCE})
target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples)
target_link_libraries(${TEST_TARGET} PRIVATE parakeet common)
target_compile_definitions(${TEST_TARGET} PRIVATE
PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin"
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}"
EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}"
TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD})
add_custom_target(run-${TEST_TARGET}
COMMAND $<TARGET_FILE:${TEST_TARGET}>
DEPENDS ${TEST_TARGET}
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
endfunction()
add_parakeet_transcription_test(
test-parakeet-full-jfk
test-parakeet-full.cpp
samples/jfk.wav
tests/parakeet-expected-jfk-output.txt)
add_parakeet_transcription_test(
test-parakeet-full-gb1
test-parakeet-full.cpp
samples/gb1.wav
tests/parakeet-expected-gb1-output.txt)
add_parakeet_transcription_test(
test-parakeet-full-diffusion
test-parakeet-full.cpp
samples/diffusion2023-07-03.flac
tests/parakeet-expected-diffusion-output.txt
0.95)

6
tests/librispeech-parakeet/.gitignore vendored Normal file

@ -0,0 +1,6 @@
__pycache__
*.tar.gz
*.txt
eval.conf
venv
LibriSpeech


@ -0,0 +1,15 @@
TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz
all: eval
eval:
	$(MAKE) -f eval.mk
clean:
	$(MAKE) -f eval.mk clean
get-audio:
	wget -c $(TAR_URL)
	tar -xf test-clean.tar.gz
.PHONY: all eval clean get-audio


@ -0,0 +1,57 @@
# parakeet.cpp/tests/librispeech
[LibriSpeech](https://www.openslr.org/12) is a standard dataset for
training and evaluating automatic speech recognition systems.
This directory contains a set of tools to evaluate the recognition
performance of parakeet.cpp on the LibriSpeech corpus.
## Quick Start
1. (Prerequisite) Compile `parakeet-cli` and prepare the Parakeet
model in `ggml` format.
```
$ # Execute the commands below in the project root dir.
$ cmake -B build
$ cmake --build build --config Release
```
2. Download the audio files from the LibriSpeech project.
```
$ make get-audio
```
3. Set up the environment to compute the WER score.
```
$ pip install -r requirements.txt
```
For example, if you use `virtualenv`, you can set it up as follows:
```
$ python3 -m venv venv
$ . venv/bin/activate
$ pip install -r requirements.txt
```
4. Run the benchmark test.
```
$ make
```
## How-to guides
### How to change the inference parameters
Create `eval.conf` and override variables.
```
PARAKEET_MODEL = parakeet-tdt-0.6b-v3
PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt
```
Check out `eval.mk` for more details.


@ -0,0 +1,39 @@
PYTHON = python
PARAKEET_PREFIX = ../../
PARAKEET_MODEL = parakeet-tdt-0.6b-v3
PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli
PARAKEET_FLAGS = --no-prints --output-txt
# You can create eval.conf to override the PARAKEET_* variables
# defined above.
-include eval.conf
# This follows the file structure of the LibriSpeech project.
AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac))
TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS))
# We output the evaluation result to this file.
DONE = $(PARAKEET_MODEL).txt
all: $(DONE)
$(DONE): $(TRANS_TXTS)
	$(PYTHON) eval.py > $@.tmp
	mv $@.tmp $@
# Note: This task writes to a temporary file first to
# create the target file atomically.
%.flac.txt: %.flac
	$(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp
	mv $^.tmp.txt $^.txt
archive:
	tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE)
clean:
	@rm -f $(TRANS_TXTS)
	@rm -f $(DONE)
.PHONY: all archive clean
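
The note in `eval.mk` about writing to a temporary file first can be illustrated outside of Make. The following is a minimal Python sketch of the same write-then-rename idea; the helper name `write_atomically` and the `.tmp` suffix are assumptions made for illustration and are not part of this change.

```python
import os
import tempfile


def write_atomically(path: str, text: str) -> None:
    """Write text to a sibling temp file, then rename it over the target.

    Hypothetical helper: it mirrors the `> $@.tmp` followed by `mv $@.tmp $@`
    pattern in eval.mk, so an interrupted run never leaves a partial target.
    """
    tmp_path = path + ".tmp"  # assumed suffix, chosen to resemble eval.mk
    with open(tmp_path, "w", encoding="utf-8") as fp:
        fp.write(text)
    # os.replace overwrites the destination in one step when both paths
    # live on the same filesystem.
    os.replace(tmp_path, path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, "result.txt")
        write_atomically(target, "placeholder result\n")
        with open(target, encoding="utf-8") as fp:
            print(fp.read(), end="")
```

The benefit for a Make-driven evaluation is that a target file only exists once its content is complete, so a rerun after an interruption does not treat a half-written file as up to date.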


@ -0,0 +1,47 @@
import os
import glob
import jiwer
from normalizers import EnglishTextNormalizer

def get_reference():
    ref = {}
    for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'):
        with open(path) as fp:
            for line in fp:
                code, text = line.strip().split(" ", maxsplit=1)
                ref[code] = text
    return ref

def get_hypothesis():
    hyp = {}
    for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'):
        with open(path) as fp:
            text = fp.read().strip()
        code = os.path.basename(path).replace('.flac.txt', '')
        hyp[code] = text
    return hyp

def get_codes():
    codes = []
    for path in glob.glob('LibriSpeech/*/*/*/*.flac'):
        codes.append(os.path.basename(path).replace('.flac', ''))
    return sorted(codes)

def main():
    normalizer = EnglishTextNormalizer()
    ref_orig = get_reference()
    hyp_orig = get_hypothesis()
    ref_clean = []
    hyp_clean = []
    for code in get_codes():
        ref_clean.append(normalizer(ref_orig[code]))
        hyp_clean.append(normalizer(hyp_orig[code]))
    wer = jiwer.wer(ref_clean, hyp_clean)
    print(f"WER: {wer * 100:.2f}%")

if __name__ == '__main__':
    main()
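
For readers unfamiliar with the metric that `eval.py` prints, word error rate is the word-level edit distance between reference and hypothesis, divided by the number of reference words. The sketch below is a dependency-free illustration for a single sentence pair; `word_error_rate` is a hypothetical name, and this is not how `jiwer` is implemented internally (`eval.py` also passes whole lists of sentences, which this sketch does not handle).

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length.

    Illustrative only: eval.py relies on jiwer.wer rather than this function.
    """
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between ref[:i-1] and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1      # reference word missing from hypothesis
            insertion = cur[j - 1] + 1  # extra word in hypothesis
            cur[j] = min(substitution, deletion, insertion)
        prev = cur
    return prev[-1] / len(ref)


if __name__ == "__main__":
    wer = word_error_rate("a b c d", "a x c d")
    print(f"WER: {wer * 100:.2f}%")
```

Because the score depends on exact word matches, both texts are passed through the same normalizer in `eval.py` before scoring, so that casing, punctuation, and number formatting do not count as errors.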


@ -0,0 +1,25 @@
Code in this directory is adapted from OpenAI Whisper project
(https://github.com/openai/whisper) and carries the following
copyright and license.
MIT License
Copyright (c) 2022 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer


@ -0,0 +1,80 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """
    return "".join(
        (
            c
            if c in keep
            else (
                ADDITIONAL_DIACRITICS[c]
                if c in ADDITIONAL_DIACRITICS
                else (
                    ""
                    if unicodedata.category(c) == "Mn"
                    else " " if unicodedata.category(c)[0] in "MSP" else c
                )
            )
        )
        for c in unicodedata.normalize("NFKD", s)
    )


def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    return "".join(
        " " if unicodedata.category(c)[0] in "MSP" else c
        for c in unicodedata.normalize("NFKC", s)
    )


class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(
            r"\s+", " ", s
        )  # replace any successive whitespace characters with a space

        return s
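
To see why the module carries a manual `ADDITIONAL_DIACRITICS` table on top of Unicode normalization, it helps to look at NFKD on its own. The sketch below is a simplified illustration, and `strip_combining_marks` is a hypothetical helper rather than a function from this file: it drops combining marks (category `Mn`) after NFKD decomposition, which handles letters such as `é` but leaves letters such as `ø` or `ł` untouched because they have no decomposition.

```python
import unicodedata


def strip_combining_marks(s: str) -> str:
    """Decompose with NFKD and drop combining marks (category 'Mn').

    Hypothetical helper for illustration; it omits the manual mapping and the
    symbol handling that remove_symbols_and_diacritics performs.
    """
    decomposed = unicodedata.normalize("NFKD", s)
    return "".join(c for c in decomposed if unicodedata.category(c) != "Mn")


if __name__ == "__main__":
    # "\u00e9" is e-acute, "\u00f8" is o-with-stroke, "\u0142" is l-with-stroke.
    for sample in ("caf\u00e9", "\u00f8re", "z\u0142oty"):
        print(sample, "->", strip_combining_marks(sample))
```

Letters in the second group only become plain ASCII through the explicit mapping, which is the role the table plays in `remove_symbols_and_diacritics`.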

File diff suppressed because it is too large


@ -0,0 +1,550 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union

from more_itertools import windowed

from .basic import remove_symbols_and_diacritics


class EnglishNumberNormalizer:
    """
    Convert any spelled-out numbers into arabic numbers, while handling:

    - remove any commas
    - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
    - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
    - spell out `one` and `ones`
    - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
    """

    def __init__(self):
        super().__init__()

        self.zeros = {"o", "oh", "zero"}
        self.ones = {
            name: i
            for i, name in enumerate(
                [
                    "one",
                    "two",
                    "three",
                    "four",
                    "five",
                    "six",
                    "seven",
                    "eight",
                    "nine",
                    "ten",
                    "eleven",
                    "twelve",
                    "thirteen",
                    "fourteen",
                    "fifteen",
                    "sixteen",
                    "seventeen",
                    "eighteen",
                    "nineteen",
                ],
                start=1,
            )
        }
        self.ones_plural = {
            "sixes" if name == "six" else name + "s": (value, "s")
            for name, value in self.ones.items()
        }
        self.ones_ordinal = {
            "zeroth": (0, "th"),
            "first": (1, "st"),
            "second": (2, "nd"),
            "third": (3, "rd"),
            "fifth": (5, "th"),
            "twelfth": (12, "th"),
            **{
                name + ("h" if name.endswith("t") else "th"): (value, "th")
                for name, value in self.ones.items()
                if value > 3 and value != 5 and value != 12
            },
        }
        self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}

        self.tens = {
            "twenty": 20,
            "thirty": 30,
            "forty": 40,
            "fifty": 50,
            "sixty": 60,
            "seventy": 70,
            "eighty": 80,
            "ninety": 90,
        }
        self.tens_plural = {
            name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
        }
        self.tens_ordinal = {
            name.replace("y", "ieth"): (value, "th")
            for name, value in self.tens.items()
        }
        self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

        self.multipliers = {
            "hundred": 100,
            "thousand": 1_000,
            "million": 1_000_000,
            "billion": 1_000_000_000,
            "trillion": 1_000_000_000_000,
            "quadrillion": 1_000_000_000_000_000,
            "quintillion": 1_000_000_000_000_000_000,
            "sextillion": 1_000_000_000_000_000_000_000,
            "septillion": 1_000_000_000_000_000_000_000_000,
            "octillion": 1_000_000_000_000_000_000_000_000_000,
            "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
            "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
        }
        self.multipliers_plural = {
            name + "s": (value, "s") for name, value in self.multipliers.items()
        }
        self.multipliers_ordinal = {
            name + "th": (value, "th") for name, value in self.multipliers.items()
        }
        self.multipliers_suffixed = {
            **self.multipliers_plural,
            **self.multipliers_ordinal,
        }
        self.decimals = {*self.ones, *self.tens, *self.zeros}

        self.preceding_prefixers = {
            "minus": "-",
            "negative": "-",
            "plus": "+",
            "positive": "+",
        }
        self.following_prefixers = {
            "pound": "£",
            "pounds": "£",
            "euro": "€",
            "euros": "€",
            "dollar": "$",
            "dollars": "$",
            "cent": "¢",
            "cents": "¢",
        }
        self.prefixes = set(
            list(self.preceding_prefixers.values())
            + list(self.following_prefixers.values())
        )
        self.suffixers = {
            "per": {"cent": "%"},
            "percent": "%",
        }
        self.specials = {"and", "double", "triple", "point"}

        self.words = set(
            [
                key
                for mapping in [
                    self.zeros,
                    self.ones,
                    self.ones_suffixed,
                    self.tens,
                    self.tens_suffixed,
                    self.multipliers,
                    self.multipliers_suffixed,
                    self.preceding_prefixers,
                    self.following_prefixers,
                    self.suffixers,
                    self.specials,
                ]
                for key in mapping
            ]
        )
        self.literal_words = {"one", "ones"}

    def process_words(self, words: List[str]) -> Iterator[str]:
        prefix: Optional[str] = None
        value: Optional[Union[str, int]] = None
        skip = False

        def to_fraction(s: str):
            try:
                return Fraction(s)
            except ValueError:
                return None

        def output(result: Union[str, int]):
            nonlocal prefix, value
            result = str(result)
            if prefix is not None:
                result = prefix + result
            value = None
            prefix = None
            return result

        if len(words) == 0:
            return

        for prev, current, next in windowed([None] + words + [None], 3):
            if skip:
                skip = False
                continue

            next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
            has_prefix = current[0] in self.prefixes
            current_without_prefix = current[1:] if has_prefix else current
            if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
                # arabic numbers (potentially with signs and fractions)
                f = to_fraction(current_without_prefix)
                assert f is not None
                if value is not None:
                    if isinstance(value, str) and value.endswith("."):
                        # concatenate decimals / ip address components
                        value = str(value) + str(current)
                        continue
                    else:
                        yield output(value)

                prefix = current[0] if has_prefix else prefix
                if f.denominator == 1:
                    value = f.numerator  # store integers as int
                else:
                    value = current_without_prefix
            elif current not in self.words:
                # non-numeric words
                if value is not None:
                    yield output(value)
                yield output(current)
            elif current in self.zeros:
                value = str(value or "") + "0"
            elif current in self.ones:
                ones = self.ones[current]

                if value is None:
                    value = ones
                elif isinstance(value, str) or prev in self.ones:
                    if (
                        prev in self.tens and ones < 10
                    ):  # replace the last zero with the digit
                        assert value[-1] == "0"
                        value = value[:-1] + str(ones)
                    else:
                        value = str(value) + str(ones)
                elif ones < 10:
                    if value % 10 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
            elif current in self.ones_suffixed:
                # ordinal or cardinal; yield the number right away
                ones, suffix = self.ones_suffixed[current]
                if value is None:
                    yield output(str(ones) + suffix)
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:
                        assert value[-1] == "0"
                        yield output(value[:-1] + str(ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                elif ones < 10:
                    if value % 10 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                value = None
            elif current in self.tens:
                tens = self.tens[current]
                if value is None:
                    value = tens
                elif isinstance(value, str):
                    value = str(value) + str(tens)
                else:
                    if value % 100 == 0:
                        value += tens
                    else:
                        value = str(value) + str(tens)
            elif current in self.tens_suffixed:
                # ordinal or cardinal; yield the number right away
                tens, suffix = self.tens_suffixed[current]
                if value is None:
                    yield output(str(tens) + suffix)
                elif isinstance(value, str):
                    yield output(str(value) + str(tens) + suffix)
                else:
                    if value % 100 == 0:
                        yield output(str(value + tens) + suffix)
                    else:
                        yield output(str(value) + str(tens) + suffix)
            elif current in self.multipliers:
                multiplier = self.multipliers[current]
                if value is None:
                    value = multiplier
                elif isinstance(value, str) or value == 0:
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        value = p.numerator
                    else:
                        yield output(value)
                        value = multiplier
                else:
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
            elif current in self.multipliers_suffixed:
                multiplier, suffix = self.multipliers_suffixed[current]
                if value is None:
                    yield output(str(multiplier) + suffix)
                elif isinstance(value, str):
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        yield output(str(p.numerator) + suffix)
                    else:
                        yield output(value)
                        yield output(str(multiplier) + suffix)
                else:  # int
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
                    yield output(str(value) + suffix)
                value = None
            elif current in self.preceding_prefixers:
                # apply prefix (positive, minus, etc.) if it precedes a number
                if value is not None:
                    yield output(value)

                if next in self.words or next_is_numeric:
                    prefix = self.preceding_prefixers[current]
                else:
                    yield output(current)
            elif current in self.following_prefixers:
                # apply prefix (dollars, cents, etc.) only after a number
                if value is not None:
                    prefix = self.following_prefixers[current]
                    yield output(value)
                else:
                    yield output(current)
            elif current in self.suffixers:
                # apply suffix symbols (percent -> '%')
                if value is not None:
                    suffix = self.suffixers[current]
                    if isinstance(suffix, dict):
                        if next in suffix:
                            yield output(str(value) + suffix[next])
                            skip = True
                        else:
                            yield output(value)
                            yield output(current)
                    else:
                        yield output(str(value) + suffix)
                else:
                    yield output(current)
            elif current in self.specials:
                if next not in self.words and not next_is_numeric:
                    # apply special handling only if the next word can be numeric
                    if value is not None:
                        yield output(value)
                    yield output(current)
                elif current == "and":
                    # ignore "and" after hundreds, thousands, etc.
                    if prev not in self.multipliers:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "double" or current == "triple":
                    if next in self.ones or next in self.zeros:
                        repeats = 2 if current == "double" else 3
                        ones = self.ones.get(next, 0)
                        value = str(value or "") + str(ones) * repeats
                        skip = True
                    else:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "point":
                    if next in self.decimals or next_is_numeric:
                        value = str(value or "") + "."
                else:
                    # should all have been covered at this point
                    raise ValueError(f"Unexpected token: {current}")
            else:
                # all should have been covered at this point
                raise ValueError(f"Unexpected token: {current}")

        if value is not None:
            yield output(value)

    def preprocess(self, s: str):
        # replace "<number> and a half" with "<number> point five"
        results = []

        segments = re.split(r"\band\s+a\s+half\b", s)
        for i, segment in enumerate(segments):
            if len(segment.strip()) == 0:
                continue
            if i == len(segments) - 1:
                results.append(segment)
            else:
                results.append(segment)
                last_word = segment.rsplit(maxsplit=2)[-1]
                if last_word in self.decimals or last_word in self.multipliers:
                    results.append("point five")
                else:
                    results.append("and a half")

        s = " ".join(results)

        # put a space at number/letter boundary
        s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
        s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

        # but remove spaces which could be a suffix
        s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

        return s

    def postprocess(self, s: str):
        def combine_cents(m: Match):
            try:
                currency = m.group(1)
                integer = m.group(2)
                cents = int(m.group(3))
                return f"{currency}{integer}.{cents:02d}"
            except ValueError:
                return m.string

        def extract_cents(m: Match):
            try:
                return f"¢{int(m.group(1))}"
            except ValueError:
                return m.string

        # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
        s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
        s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)

        # write "one(s)" instead of "1(s)", just for the readability
        s = re.sub(r"\b1(s?)\b", r"one\1", s)

        return s

    def __call__(self, s: str):
        s = self.preprocess(s)
        s = " ".join(word for word in self.process_words(s.split()) if word is not None)
        s = self.postprocess(s)

        return s


class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        # close the file handle once the mapping has been loaded
        with open(mapping_path) as fp:
            self.mapping = json.load(fp)

    def __call__(self, s: str):
        return " ".join(self.mapping.get(word, word) for word in s.split())


class EnglishTextNormalizer:
    def __init__(self):
        self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
        self.replacers = {
            # common contractions
            r"\bwon't\b": "will not",
            r"\bcan't\b": "can not",
            r"\blet's\b": "let us",
            r"\bain't\b": "aint",
            r"\by'all\b": "you all",
            r"\bwanna\b": "want to",
            r"\bgotta\b": "got to",
            r"\bgonna\b": "going to",
            r"\bi'ma\b": "i am going to",
            r"\bimma\b": "i am going to",
            r"\bwoulda\b": "would have",
            r"\bcoulda\b": "could have",
            r"\bshoulda\b": "should have",
            r"\bma'am\b": "madam",
            # contractions in titles/prefixes
            r"\bmr\b": "mister ",
            r"\bmrs\b": "missus ",
            r"\bst\b": "saint ",
            r"\bdr\b": "doctor ",
            r"\bprof\b": "professor ",
            r"\bcapt\b": "captain ",
            r"\bgov\b": "governor ",
            r"\bald\b": "alderman ",
            r"\bgen\b": "general ",
            r"\bsen\b": "senator ",
            r"\brep\b": "representative ",
            r"\bpres\b": "president ",
            r"\brev\b": "reverend ",
            r"\bhon\b": "honorable ",
            r"\basst\b": "assistant ",
            r"\bassoc\b": "associate ",
            r"\blt\b": "lieutenant ",
            r"\bcol\b": "colonel ",
            r"\bjr\b": "junior ",
            r"\bsr\b": "senior ",
            r"\besq\b": "esquire ",
            # perfect tenses, ideally it should be any past participles, but it's harder..
            r"'d been\b": " had been",
            r"'s been\b": " has been",
            r"'d gone\b": " had gone",
            r"'s gone\b": " has gone",
            r"'d done\b": " had done",  # "'s done" is ambiguous
            r"'s got\b": " has got",
            # general contractions
            r"n't\b": " not",
            r"'re\b": " are",
            r"'s\b": " is",
            r"'d\b": " would",
            r"'ll\b": " will",
            r"'t\b": " not",
            r"'ve\b": " have",
            r"'m\b": " am",
        }
        self.standardize_numbers = EnglishNumberNormalizer()
        self.standardize_spellings = EnglishSpellingNormalizer()

    def __call__(self, s: str):
        s = s.lower()

        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
        s = re.sub(self.ignore_patterns, "", s)
        s = re.sub(r"\s+'", "'", s)  # when there's a space before an apostrophe

        for pattern, replacement in self.replacers.items():
            s = re.sub(pattern, replacement, s)

        s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
        s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols

        s = self.standardize_numbers(s)
        s = self.standardize_spellings(s)

        # now remove prefix/suffix symbols that are not preceded/followed by numbers
        s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
        s = re.sub(r"([^0-9])%", r"\1 ", s)

        s = re.sub(r"\s+", " ", s)  # replace any successive whitespaces with a space

        return s
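
One detail of `EnglishTextNormalizer` that is easy to miss is that the `replacers` dictionary is applied in insertion order, so specific patterns such as `won't` have to come before the general `n't` rule. The sketch below reproduces that behavior on a four-entry subset of the table; `expand_contractions` and the two dictionaries are hypothetical names introduced only for this illustration.

```python
import re

# A small subset of the replacers table, in the same relative order.
SPECIFIC_FIRST = {
    r"\bwon't\b": "will not",
    r"\bcan't\b": "can not",
    r"n't\b": " not",
    r"'re\b": " are",
}

# The same rules with the general "n't" rule moved to the front.
GENERAL_FIRST = {
    r"n't\b": " not",
    r"\bwon't\b": "will not",
    r"\bcan't\b": "can not",
    r"'re\b": " are",
}


def expand_contractions(s: str, replacers: dict) -> str:
    """Apply each regex replacement in dictionary insertion order."""
    s = s.lower()
    for pattern, replacement in replacers.items():
        s = re.sub(pattern, replacement, s)
    return s


if __name__ == "__main__":
    print(expand_contractions("We won't go", SPECIFIC_FIRST))
    print(expand_contractions("We won't go", GENERAL_FIRST))
```

With the general rule first, `won't` is rewritten before the specific pattern gets a chance to match, which is why the ordering of the table matters when entries are added or moved.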

File diff suppressed because one or more lines are too long


@ -0,0 +1 @@
My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America.


@ -0,0 +1 @@
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.


@ -0,0 +1,110 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>
#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD
#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0
#endif
static std::string read_expected_transcription(const char * path) {
std::ifstream fin(path);
assert(fin.is_open());
std::string text(
(std::istreambuf_iterator<char>(fin)),
std::istreambuf_iterator<char>());
while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) {
text.pop_back();
}
return text;
}
static std::vector<std::string> transcription_words(const std::string & text) {
std::vector<std::string> words;
std::string word;
for (unsigned char ch : text) {
if (std::isalnum(ch)) {
word.push_back((char) std::tolower(ch));
} else if (!word.empty()) {
words.push_back(word);
word.clear();
}
}
if (!word.empty()) {
words.push_back(word);
}
return words;
}
static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) {
const std::vector<std::string> expected_words = transcription_words(expected);
const std::vector<std::string> actual_words = transcription_words(actual);
if (expected_words.empty() && actual_words.empty()) {
return 1.0;
}
if (expected_words.empty() || actual_words.empty()) {
return 0.0;
}
std::vector<int> prev(actual_words.size() + 1, 0);
std::vector<int> cur (actual_words.size() + 1, 0);
for (size_t i = 0; i < expected_words.size(); ++i) {
std::fill(cur.begin(), cur.end(), 0);
for (size_t j = 0; j < actual_words.size(); ++j) {
if (expected_words[i] == actual_words[j]) {
cur[j + 1] = prev[j] + 1;
} else {
cur[j + 1] = std::max(prev[j + 1], cur[j]);
}
}
prev.swap(cur);
}
const int lcs = prev[actual_words.size()];
return (2.0 * lcs) / (expected_words.size() + actual_words.size());
}
static bool verify_transcription(const std::string & expected, const std::string & actual) {
const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD;
if (threshold >= 1.0) {
if (actual == expected) {
return true;
}
fprintf(stderr, "\n\n");
fprintf(stderr, "[Failed] Transcript mismatched\n");
fprintf(stderr, "expected:\n%s\n\n", expected.c_str());
fprintf(stderr, "actual:\n%s\n", actual.c_str());
return false;
}
const double similarity = transcription_lcs_similarity(expected, actual);
printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold);
if (similarity >= threshold) {
return true;
}
fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold);
fprintf(stderr, "Expected:\n%s\n\n", expected.c_str());
fprintf(stderr, "Actual:\n%s\n", actual.c_str());
return false;
}
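
The similarity score used by `verify_transcription` when the threshold is below 1.0 is `2 * LCS / (len(expected) + len(actual))`, computed over lowercased alphanumeric words. The following Python sketch mirrors that logic for experimentation; it is an illustrative port that assumes ASCII input, and the function names are chosen here rather than taken from the header.

```python
import re
from typing import List


def transcription_words(text: str) -> List[str]:
    """Split into runs of ASCII letters and digits, lowercased."""
    return [word.lower() for word in re.findall(r"[A-Za-z0-9]+", text)]


def lcs_similarity(expected: str, actual: str) -> float:
    """Return 2 * LCS / (n_expected + n_actual), computed over words."""
    expected_words = transcription_words(expected)
    actual_words = transcription_words(actual)
    if not expected_words and not actual_words:
        return 1.0
    if not expected_words or not actual_words:
        return 0.0

    # Rolling-row dynamic program: prev[j] is the LCS length of the expected
    # words processed so far against the first j actual words.
    prev = [0] * (len(actual_words) + 1)
    for expected_word in expected_words:
        cur = [0] * (len(actual_words) + 1)
        for j, actual_word in enumerate(actual_words):
            if expected_word == actual_word:
                cur[j + 1] = prev[j] + 1
            else:
                cur[j + 1] = max(prev[j + 1], cur[j])
        prev = cur
    return 2.0 * prev[len(actual_words)] / (len(expected_words) + len(actual_words))


if __name__ == "__main__":
    print(lcs_similarity("Ask not, what!", "ask not what"))
    print(lcs_similarity("a b c d", "a x c d"))
```

Because punctuation and case are discarded before comparison, a threshold such as 0.95 tolerates small wording differences, whereas the default threshold of 1.0 falls back to an exact string comparison in the header.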


@ -21,13 +21,21 @@ cd `dirname $0`
# Whisper models
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" )
# Parakeet model variants
parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" )
# list available models
function list_models {
printf "\n"
- printf " Available models:"
+ printf " Available whisper models:"
for model in "${models[@]}"; do
printf " $model"
done
printf "\n"
printf " Available parakeet models:"
for model in "${parakeet_models[@]}"; do
printf " parakeet-$model"
done
printf "\n\n"
}
@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then
fi
model=$1
main="../build/bin/whisper-cli"
threads=""
if [ $# -eq 2 ]; then
threads="-t $2"
fi
- if [ ! -f ../models/ggml-$model.bin ]; then
- printf "Model $model not found. Aborting\n"
# Detect parakeet model (prefix "parakeet-" or a bare variant like "f32")
is_parakeet=0
parakeet_variant=""
if [[ $model == parakeet-* ]]; then
is_parakeet=1
parakeet_variant="${model#parakeet-}"
fi
for v in "${parakeet_models[@]}"; do
if [[ $model == "$v" ]]; then
is_parakeet=1
parakeet_variant="$v"
break
fi
done
if [ $is_parakeet -eq 1 ]; then
main="../build/bin/parakeet-cli"
model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin"
else
main="../build/bin/whisper-cli"
model_path="../models/ggml-${model}.bin"
fi
if [ ! -f $model_path ]; then
printf "Model $model not found ($model_path). Aborting\n"
list_models
exit 1
fi
@ -110,7 +140,11 @@ function run_lang() {
fi
fi
- $main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null
+ if [ $is_parakeet -eq 1 ]; then
+ $main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null
+ else
+ $main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null
+ fi
git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt
@ -120,7 +154,7 @@ function run_lang() {
run_lang "en" "${urls_en[@]}"
- if [[ $model != *.en* ]]; then
+ if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then
run_lang "es" "${urls_es[@]}"
run_lang "it" "${urls_it[@]}"
run_lang "pt" "${urls_pt[@]}"
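
The model detection added to the test script can be summarized as: a name with the `parakeet-` prefix, or a bare quantization variant such as `f32`, selects `parakeet-cli` and the `ggml-parakeet-tdt-0.6b-v3-<variant>.bin` file, while anything else falls back to `whisper-cli`. The Python sketch below restates that rule for clarity; `resolve_model` is a hypothetical helper and not part of the script.

```python
from typing import Tuple

# Variant names copied from the parakeet_models array in the script.
PARAKEET_VARIANTS = ("f16", "f32", "q2_k", "q4_0", "q4_k", "q8_0")


def resolve_model(model: str) -> Tuple[str, str]:
    """Return (cli_path, model_path) for a model name given on the command line.

    Hypothetical restatement of the shell logic, for illustration only.
    """
    variant = None
    if model.startswith("parakeet-"):
        variant = model[len("parakeet-"):]
    elif model in PARAKEET_VARIANTS:
        variant = model

    if variant is not None:
        return (
            "../build/bin/parakeet-cli",
            f"../models/ggml-parakeet-tdt-0.6b-v3-{variant}.bin",
        )
    return ("../build/bin/whisper-cli", f"../models/ggml-{model}.bin")


if __name__ == "__main__":
    for name in ("parakeet-q8_0", "f32", "base.en"):
        print(name, "->", resolve_model(name))
```

As in the script, the prefix form is not validated against the variant list, so an unknown variant only fails later when the model file is not found.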


@ -0,0 +1,101 @@
#include "parakeet.h"
#include "common-whisper.h"
#include "parakeet-verification.h"
#include <cstdio>
#include <string>
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <cassert>
struct test_state {
bool is_first = true;
std::string transcript;
};
void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) {
bool * called = static_cast<bool *>(user_data);
*called = true;
}
bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) {
bool * called = static_cast<bool *>(user_data);
*called = true;
return true;
}
bool abort_callback(void * user_data) {
bool * called = static_cast<bool *>(user_data);
*called = true;
return false; // just continue without aborting.
}
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
test_state * tstate = static_cast<test_state *>(user_data);
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
char text_buf[256];
parakeet_token_to_text(token_str, tstate->is_first, text_buf, sizeof(text_buf));
printf("%s", text_buf);
fflush(stdout);
tstate->transcript += text_buf;
tstate->is_first = false;
}
int main() {
std::string model_path = PARAKEET_MODEL_PATH;
std::string sample_path = SAMPLE_PATH;
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
assert(pcmf32.size() > 0);
assert(pcmf32s.size() == 0); // no stereo vector
printf("Loading Parakeet model from: %s\n", model_path.c_str());
struct parakeet_context_params ctx_params = parakeet_context_default_params();
struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params);
if (pctx == nullptr) {
fprintf(stderr, "Failed to load Parakeet model\n");
return 1;
}
printf("Successfully loaded Parakeet model\n");
struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
test_state tstate;
params.new_token_callback = token_callback;
params.new_token_callback_user_data = &tstate;
bool progress_callback_called = false;
params.progress_callback = progress_callback;
params.progress_callback_user_data = &progress_callback_called;
bool encoder_begin_callback_called = false;
params.encoder_begin_callback = encoder_begin_callback;
params.encoder_begin_callback_user_data = &encoder_begin_callback_called;
bool abort_callback_called = false;
params.abort_callback = abort_callback;
params.abort_callback_user_data = &abort_callback_called;
int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size());
assert(ret == 0);
assert(progress_callback_called);
assert(encoder_begin_callback_called);
assert(abort_callback_called);
const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH);
const bool transcript_matches = verify_transcription(expected, tstate.transcript);
parakeet_free(pctx);
if (!transcript_matches) {
return 1;
}
printf("\nTest passed: parakeet_full succeeded!\n");
return 0;
}

99
tests/test-parakeet.cpp Normal file

@ -0,0 +1,99 @@
#include "parakeet.h"
#include "common-whisper.h"
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <cassert>
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
static bool is_first = true;
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
char text_buf[256];
parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf));
int32_t time_ms = token_data->frame_index * 10;
printf("%s", text_buf);
fflush(stdout);
is_first = false;
}
void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) {
const int n_segments = parakeet_full_n_segments_from_state(state);
const int s0 = n_segments - n_new;
printf("\nSegment Callback: %d new segment(s)\n", n_new);
for (int i = s0; i < n_segments; i++) {
const char * text = parakeet_full_get_segment_text_from_state(state, i);
const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i);
const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i);
printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
printf("Tokens:\n");
const int n_tokens = parakeet_full_n_tokens_from_state(state, i);
for (int j = 0; j < n_tokens; j++) {
parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j);
const char * token_str = parakeet_token_to_str(ctx, token_data.id);
printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n",
j,
token_data.id,
token_data.frame_index,
token_data.duration_idx,
token_data.duration_value,
token_data.p,
token_data.plog,
(long long)token_data.t0,
(long long)token_data.t1,
token_data.is_word_start,
token_str);
}
}
printf("\n");
}
int main() {
std::string model_path = PARAKEET_MODEL_PATH;
std::string sample_path = SAMPLE_PATH;
// Load the sample audio file
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
assert(pcmf32.size() > 0);
assert(pcmf32s.size() == 0);
printf("Loading Parakeet model from: %s\n", model_path.c_str());
struct parakeet_context_params ctx_params = parakeet_context_default_params();
struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params);
if (pctx == nullptr) {
fprintf(stderr, "Failed to load Parakeet model\n");
return 1;
}
printf("Successfully loaded Parakeet model\n");
struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
params.new_token_callback = token_callback;
params.new_token_callback_user_data = nullptr;
params.new_segment_callback = segment_callback;
params.new_segment_callback_user_data = nullptr;
parakeet_state * state = parakeet_init_state(pctx);
int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size());
assert(ret == 0);
parakeet_free_state(state);
parakeet_free(pctx);
printf("\nTest passed: parakeet_chunk succeeded!\n");
return 0;
}