parakeet : add support for NVIDIA Parakeet (#3735)
* parakeet : add support for NVIDIA Parakeet Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
3805e602d3
commit
9efddafb91
|
|
@ -180,12 +180,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location
|
|||
get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
|
||||
|
||||
set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h)
|
||||
|
||||
install(TARGETS whisper LIBRARY PUBLIC_HEADER)
|
||||
|
||||
target_compile_definitions(whisper PRIVATE
|
||||
WHISPER_VERSION="${PROJECT_VERSION}"
|
||||
)
|
||||
|
||||
set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h)
|
||||
install(TARGETS parakeet LIBRARY PUBLIC_HEADER)
|
||||
|
||||
target_compile_definitions(parakeet PRIVATE
|
||||
PARAKEET_VERSION="${PROJECT_VERSION}"
|
||||
)
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake
|
||||
|
|
@ -211,6 +219,35 @@ configure_file(cmake/whisper.pc.in
|
|||
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc"
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
|
||||
|
||||
set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
||||
set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet
|
||||
PATH_VARS
|
||||
PARAKEET_INCLUDE_INSTALL_DIR
|
||||
PARAKEET_LIB_INSTALL_DIR
|
||||
PARAKEET_BIN_INSTALL_DIR)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake
|
||||
VERSION ${WHISPER_INSTALL_VERSION}
|
||||
COMPATIBILITY SameMajorVersion)
|
||||
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet)
|
||||
|
||||
configure_file(cmake/parakeet.pc.in
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc"
|
||||
@ONLY)
|
||||
|
||||
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc"
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
|
||||
|
||||
#
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
|
|
|||
|
|
@ -30,6 +30,6 @@ create_makefile "whisper" do |conf|
|
|||
#{libs}: cmake-targets
|
||||
cmake-targets:
|
||||
#{"\t"}"#{cmake}" -S sources -B build #{options}
|
||||
#{"\t"}"#{cmake}" --build build --config Release --target common whisper
|
||||
#{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet
|
||||
EOF
|
||||
end
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@)
|
||||
set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@)
|
||||
set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@)
|
||||
set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@)
|
||||
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@")
|
||||
set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake)
|
||||
|
||||
find_library(parakeet_LIBRARY parakeet
|
||||
REQUIRED
|
||||
HINTS ${PARAKEET_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH
|
||||
)
|
||||
|
||||
add_library(parakeet UNKNOWN IMPORTED)
|
||||
set_target_properties(parakeet
|
||||
PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}"
|
||||
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
|
||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||
IMPORTED_LOCATION "${parakeet_LIBRARY}"
|
||||
INTERFACE_COMPILE_FEATURES cxx_std_11
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
check_required_components(parakeet)
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
prefix=@CMAKE_INSTALL_PREFIX@
|
||||
exec_prefix=${prefix}
|
||||
libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@
|
||||
includedir=${prefix}/include
|
||||
|
||||
Name: parakeet
|
||||
Description: Port of NVIDIA's Parakeet model in C/C++
|
||||
Version: @PROJECT_VERSION@
|
||||
Libs: -L${libdir} -lggml -lggml-base -lparakeet
|
||||
Cflags: -I${includedir}
|
||||
|
|
@ -107,6 +107,8 @@ else()
|
|||
add_subdirectory(server)
|
||||
add_subdirectory(quantize)
|
||||
add_subdirectory(vad-speech-segments)
|
||||
add_subdirectory(parakeet-cli)
|
||||
add_subdirectory(parakeet-quantize)
|
||||
if (WHISPER_SDL2)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
set(TARGET parakeet-cli)
|
||||
add_executable(${TARGET} parakeet-cli.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
# whisper.cpp/examples/parakeet-cli
|
||||
|
||||
This is an example of using the [Parakeet] model in whisper.cpp.
|
||||
|
||||
### Download converted model
|
||||
```console
|
||||
$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models
|
||||
```
|
||||
|
||||
### Building
|
||||
```console
|
||||
$ cmake -B build -S .
|
||||
$ cmake --build build --target parakeet-cli -j 12
|
||||
```
|
||||
|
||||
### Usage
|
||||
```console
|
||||
$ ./build/bin/parakeet-cli --help
|
||||
|
||||
usage: ./build/bin/parakeet-cli [options] file0 file1 ...
|
||||
supported audio formats: flac, mp3, ogg, wav
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path
|
||||
-f, --file FILE [ ] input audio file
|
||||
-ng, --no-gpu [false ] disable GPU
|
||||
-dev N, --device N [0 ] GPU device to use
|
||||
-ps, --print-segments [false ] print segment information
|
||||
```
|
||||
|
||||
### Example
|
||||
```console
|
||||
$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav
|
||||
Processing audio (176000 samples, 11.00 seconds)
|
||||
Processing audio: total_frames=1101, chunk_size=1101
|
||||
parakeet_decode: starting decode with n_frames=138
|
||||
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
```
|
||||
|
||||
To print segment information:
|
||||
```console
|
||||
$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments
|
||||
Processing audio (176000 samples, 11.00 seconds)
|
||||
Processing audio: total_frames=1101, chunk_size=1101
|
||||
parakeet_decode: starting decode with n_frames=138
|
||||
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
||||
Segments (1):
|
||||
Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country."
|
||||
Tokens [38]:
|
||||
[ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And"
|
||||
[ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so"
|
||||
[ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false ","
|
||||
[ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my"
|
||||
[ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f"
|
||||
[ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell"
|
||||
[ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow"
|
||||
[ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer"
|
||||
[ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic"
|
||||
[ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans"
|
||||
[10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false ","
|
||||
[11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a"
|
||||
[12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk"
|
||||
[13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not"
|
||||
[14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what"
|
||||
[15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your"
|
||||
[16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co"
|
||||
[17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un"
|
||||
[18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr"
|
||||
[19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y"
|
||||
[20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can"
|
||||
[21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do"
|
||||
[22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for"
|
||||
[23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you"
|
||||
[24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false ","
|
||||
[25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a"
|
||||
[26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk"
|
||||
[27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what"
|
||||
[28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you"
|
||||
[29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can"
|
||||
[30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do"
|
||||
[31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for"
|
||||
[32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your"
|
||||
[33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co"
|
||||
[34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un"
|
||||
[35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr"
|
||||
[36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y"
|
||||
[37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "."
|
||||
```
|
||||
|
||||
### Model conversion
|
||||
Clone the original model from Hugging Face:
|
||||
```console
|
||||
$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
|
||||
```
|
||||
Convert the model:
|
||||
```console
|
||||
(venv) $ python models/convert-parakeet-to-ggml.py \
|
||||
--model <path to cloned model> \
|
||||
--out-dir models \
|
||||
--out-name ggml-parakeet-tdt-0.6b-v3-f16.bin
|
||||
```
|
||||
|
||||
[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
#include "parakeet.h"
|
||||
#include "common-whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
|
||||
// command-line parameters
|
||||
struct parakeet_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
|
||||
bool use_gpu = true;
|
||||
int32_t gpu_device = 0;
|
||||
|
||||
bool print_segments = false;
|
||||
bool output_txt = false;
|
||||
bool no_prints = false;
|
||||
|
||||
std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin";
|
||||
std::string output_file = "";
|
||||
std::vector<std::string> fname_inp = {};
|
||||
};
|
||||
|
||||
static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params);
|
||||
|
||||
static char * requires_value_error(const std::string & arg) {
|
||||
fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) {
|
||||
if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) {
|
||||
params.gpu_device = std::stoi(env_device);
|
||||
}
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-"){
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg[0] != '-') {
|
||||
params.fname_inp.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
parakeet_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; }
|
||||
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
parakeet_print_usage(argc, argv, params);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]);
|
||||
fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", "");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device);
|
||||
fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", "");
|
||||
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
|
||||
bool * is_first = (bool *) user_data;
|
||||
|
||||
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
|
||||
char text_buf[256];
|
||||
parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf));
|
||||
printf("%s", text_buf);
|
||||
fflush(stdout);
|
||||
|
||||
*is_first = false;
|
||||
}
|
||||
|
||||
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
parakeet_params params;
|
||||
|
||||
if (parakeet_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.no_prints) {
|
||||
parakeet_log_set(cb_log_disable, NULL);
|
||||
}
|
||||
|
||||
if (params.fname_inp.empty()) {
|
||||
fprintf(stderr, "error: no input files specified\n");
|
||||
parakeet_print_usage(argc, argv, params);
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct parakeet_context_params ctx_params = parakeet_context_default_params();
|
||||
ctx_params.use_gpu = params.use_gpu;
|
||||
ctx_params.gpu_device = params.gpu_device;
|
||||
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
|
||||
}
|
||||
|
||||
|
||||
struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
|
||||
if (pctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "Successfully loaded Parakeet model\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
|
||||
}
|
||||
|
||||
// Process each input file
|
||||
for (const auto & fname : params.fname_inp) {
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "\nProcessing file: %s\n", fname.c_str());
|
||||
}
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) {
|
||||
fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pcmf32.empty()) {
|
||||
fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_first = true;
|
||||
struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
full_params.n_threads = params.n_threads;
|
||||
full_params.new_token_callback = token_callback;
|
||||
full_params.new_token_callback_user_data = &is_first;
|
||||
|
||||
const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH);
|
||||
int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size());
|
||||
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
if (params.output_txt) {
|
||||
const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt";
|
||||
|
||||
std::ofstream fout(fname_out);
|
||||
if (fout.is_open()) {
|
||||
const int n_segments = parakeet_full_n_segments(pctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = parakeet_full_get_segment_text(pctx, i);
|
||||
fout << text << "\n";
|
||||
}
|
||||
fout.close();
|
||||
if (!params.no_prints) {
|
||||
fprintf(stderr, "Output written to: %s\n", fname_out.c_str());
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.no_prints) {
|
||||
parakeet_print_timings(pctx);
|
||||
}
|
||||
|
||||
if (params.print_segments) {
|
||||
const int n_segments = parakeet_full_n_segments(pctx);
|
||||
fprintf(stderr, "\nSegments (%d):\n", n_segments);
|
||||
|
||||
for (int i = 0; i < n_segments; i++) {
|
||||
const char * text = parakeet_full_get_segment_text(pctx, i);
|
||||
const int64_t t0 = parakeet_full_get_segment_t0(pctx, i);
|
||||
const int64_t t1 = parakeet_full_get_segment_t1(pctx, i);
|
||||
const int n_tokens = parakeet_full_n_tokens(pctx, i);
|
||||
|
||||
fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
|
||||
fprintf(stderr, "Tokens [%d]:\n", n_tokens);
|
||||
|
||||
for (int j = 0; j < n_tokens; j++) {
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j);
|
||||
const char * token_str = parakeet_token_to_str(pctx, token_data.id);
|
||||
|
||||
fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n",
|
||||
j,
|
||||
token_data.id,
|
||||
token_data.frame_index,
|
||||
token_data.duration_idx,
|
||||
token_data.duration_value,
|
||||
token_data.p,
|
||||
token_data.plog,
|
||||
(long long)token_data.t0,
|
||||
(long long)token_data.t1,
|
||||
token_data.is_word_start ? "true": "false",
|
||||
token_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parakeet_free(pctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
set(TARGET parakeet-quantize)
|
||||
add_executable(${TARGET} parakeet-quantize.cpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common parakeet ${CMAKE_THREAD_LIBS_INIT})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include "common-ggml.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct parakeet_hparams {
|
||||
int32_t n_vocab = 0;
|
||||
int32_t n_audio_ctx = 0;
|
||||
int32_t n_audio_state = 0;
|
||||
int32_t n_audio_head = 0;
|
||||
int32_t n_audio_layer = 0;
|
||||
int32_t n_mels = 0;
|
||||
int32_t ftype = 0;
|
||||
int32_t n_fft = 0;
|
||||
int32_t subsampling_factor = 0;
|
||||
int32_t n_subsampling_channels = 0;
|
||||
int32_t n_conv_kernel = 0;
|
||||
int32_t n_pred_dim = 0;
|
||||
int32_t n_pred_layers = 0;
|
||||
int32_t n_tdt_durations = 0;
|
||||
int32_t n_max_tokens = 0;
|
||||
};
|
||||
|
||||
static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
|
||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||
|
||||
auto finp = std::ifstream(fname_inp, std::ios::binary);
|
||||
if (!finp) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto fout = std::ofstream(fname_out, std::ios::binary);
|
||||
if (!fout) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// magic
|
||||
{
|
||||
uint32_t magic;
|
||||
finp.read((char *) &magic, sizeof(magic));
|
||||
if (magic != GGML_FILE_MAGIC) {
|
||||
fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__);
|
||||
return false;
|
||||
}
|
||||
fout.write((char *) &magic, sizeof(magic));
|
||||
}
|
||||
|
||||
// hparams
|
||||
parakeet_hparams hparams;
|
||||
{
|
||||
finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
|
||||
finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft));
|
||||
finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor));
|
||||
finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels));
|
||||
finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel));
|
||||
finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim));
|
||||
finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers));
|
||||
finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations));
|
||||
finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens));
|
||||
|
||||
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
|
||||
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
|
||||
|
||||
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
||||
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
|
||||
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
||||
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
|
||||
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
|
||||
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
|
||||
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
|
||||
|
||||
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
|
||||
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
|
||||
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
|
||||
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
|
||||
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
|
||||
fout.write((char *) &ftype_dst, sizeof(ftype_dst));
|
||||
fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft));
|
||||
fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor));
|
||||
fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels));
|
||||
fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel));
|
||||
fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim));
|
||||
fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers));
|
||||
fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations));
|
||||
fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens));
|
||||
}
|
||||
|
||||
// mel filterbank
|
||||
{
|
||||
int32_t n_mel, n_fb;
|
||||
finp.read((char *) &n_mel, sizeof(n_mel));
|
||||
fout.write((char *) &n_mel, sizeof(n_mel));
|
||||
finp.read((char *) &n_fb, sizeof(n_fb));
|
||||
fout.write((char *) &n_fb, sizeof(n_fb));
|
||||
|
||||
const size_t n = (size_t) n_mel * n_fb;
|
||||
std::vector<float> buf(n);
|
||||
finp.read((char *) buf.data(), n * sizeof(float));
|
||||
fout.write((char *) buf.data(), n * sizeof(float));
|
||||
}
|
||||
|
||||
// window function
|
||||
{
|
||||
int32_t n_window;
|
||||
finp.read((char *) &n_window, sizeof(n_window));
|
||||
fout.write((char *) &n_window, sizeof(n_window));
|
||||
|
||||
std::vector<float> buf(n_window);
|
||||
finp.read((char *) buf.data(), n_window * sizeof(float));
|
||||
fout.write((char *) buf.data(), n_window * sizeof(float));
|
||||
}
|
||||
|
||||
// TDT durations
|
||||
{
|
||||
std::vector<uint32_t> buf(hparams.n_tdt_durations);
|
||||
finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t));
|
||||
fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
// vocab
|
||||
{
|
||||
int32_t n_tokens;
|
||||
finp.read((char *) &n_tokens, sizeof(n_tokens));
|
||||
fout.write((char *) &n_tokens, sizeof(n_tokens));
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
int32_t len;
|
||||
finp.read((char *) &len, sizeof(len));
|
||||
fout.write((char *) &len, sizeof(len));
|
||||
|
||||
std::string token(len, '\0');
|
||||
finp.read(&token[0], len);
|
||||
fout.write(&token[0], len);
|
||||
}
|
||||
}
|
||||
|
||||
// tensors — quantize 2D weights skipping tensors that must stay F32:
|
||||
// ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights.
|
||||
// pos_bias_u / pos_bias_v are declared F32 in the loader.
|
||||
const std::vector<std::string> to_quant = { ".*" };
|
||||
std::vector<std::string> to_skip = {
|
||||
// CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights)
|
||||
"encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight",
|
||||
// Declared F32 in loader (pos_bias tensors)
|
||||
"encoder\\.layers\\..+\\.self_attn\\.pos_bias_u",
|
||||
"encoder\\.layers\\..+\\.self_attn\\.pos_bias_v",
|
||||
};
|
||||
|
||||
// Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant
|
||||
// types (block size 256) cannot quantize 640 evenly, so keep them F32. For
|
||||
// other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be
|
||||
// quantized normally. The loader mirrors this logic at load time.
|
||||
{
|
||||
const ggml_type qtype = ggml_ftype_to_ggml_type(ftype);
|
||||
const int32_t blck = ggml_blck_size(qtype);
|
||||
if (blck > 1 && hparams.n_pred_dim % blck != 0) {
|
||||
to_skip.push_back("decoder\\.prediction\\.embed\\.weight");
|
||||
to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*");
|
||||
to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*");
|
||||
to_skip.push_back("joint\\.pred\\.weight");
|
||||
to_skip.push_back("joint\\.joint_net\\.2\\.weight");
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) {
|
||||
fprintf(stderr, "%s: failed to quantize tensors\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
finp.close();
|
||||
fout.close();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (argc != 4) {
|
||||
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// initialise F16 lookup tables
|
||||
{
|
||||
struct ggml_init_params params = { 0, NULL, false };
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
const std::string fname_inp = argv[1];
|
||||
const std::string fname_out = argv[2];
|
||||
const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
|
||||
|
||||
if (ftype == GGML_FTYPE_UNKNOWN) {
|
||||
fprintf(stderr, "%s: invalid quantization type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) {
|
||||
fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f);
|
||||
printf("%s: output model = %s\n", argv[0], fname_out.c_str());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,342 @@
|
|||
#ifndef PARAKEET_H
|
||||
#define PARAKEET_H
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __GNUC__
|
||||
# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
|
||||
#elif defined(_MSC_VER)
|
||||
# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
|
||||
#else
|
||||
# define PARAKEET_DEPRECATED(func, hint) func
|
||||
#endif
|
||||
|
||||
#ifdef PARAKEET_SHARED
|
||||
# ifdef _WIN32
|
||||
# ifdef PARAKEET_BUILD
|
||||
# define PARAKEET_API __declspec(dllexport)
|
||||
# else
|
||||
# define PARAKEET_API __declspec(dllimport)
|
||||
# endif
|
||||
# else
|
||||
# define PARAKEET_API __attribute__ ((visibility ("default")))
|
||||
# endif
|
||||
#else
|
||||
# define PARAKEET_API
|
||||
#endif
|
||||
|
||||
#define PARAKEET_SAMPLE_RATE 16000
|
||||
#define PARAKEET_HOP_LENGTH 160
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct parakeet_context;
|
||||
struct parakeet_state;
|
||||
struct parakeet_full_params;
|
||||
|
||||
typedef int32_t parakeet_pos;
|
||||
typedef int32_t parakeet_token;
|
||||
typedef int32_t parakeet_seq_id;
|
||||
|
||||
struct parakeet_context_params {
|
||||
bool use_gpu;
|
||||
int gpu_device; // CUDA device
|
||||
};
|
||||
|
||||
typedef struct parakeet_token_data {
|
||||
parakeet_token id; // the BPE subword ID (0-8191)
|
||||
|
||||
int duration_idx; // index into the models durations array
|
||||
int duration_value; // actual duration value
|
||||
int frame_index;
|
||||
|
||||
float p;
|
||||
float plog;
|
||||
|
||||
int64_t t0;
|
||||
int64_t t1;
|
||||
|
||||
bool is_word_start;
|
||||
} parakeet_token_data;
|
||||
|
||||
typedef struct parakeet_model_loader {
|
||||
void * context;
|
||||
|
||||
size_t (*read)(void * ctx, void * output, size_t read_size);
|
||||
bool (*eof)(void * ctx);
|
||||
void (*close)(void * ctx);
|
||||
} parakeet_model_loader;
|
||||
|
||||
PARAKEET_API const char * parakeet_version(void);
|
||||
|
||||
// Various functions for loading a ggml parakeet model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params);
|
||||
|
||||
// These are the same as the above, but the internal state of the context is not allocated automatically
|
||||
// It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523)
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params);
|
||||
PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params);
|
||||
|
||||
PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx);
|
||||
|
||||
// Frees all allocated memory
|
||||
PARAKEET_API void parakeet_free (struct parakeet_context * ctx);
|
||||
PARAKEET_API void parakeet_free_state(struct parakeet_state * state);
|
||||
PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params);
|
||||
PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params);
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the default state of the provided parakeet context.
|
||||
// Returns 0 on success
|
||||
PARAKEET_API int parakeet_pcm_to_mel(
|
||||
struct parakeet_context * ctx,
|
||||
const float * samples,
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
PARAKEET_API int parakeet_pcm_to_mel_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
const float * samples,
|
||||
int n_samples,
|
||||
int n_threads);
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context.
|
||||
// Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 128
|
||||
// Returns 0 on success
|
||||
PARAKEET_API int parakeet_set_mel(
|
||||
struct parakeet_context * ctx,
|
||||
const float * data,
|
||||
int n_len,
|
||||
int n_mel);
|
||||
|
||||
PARAKEET_API int parakeet_set_mel_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
const float * data,
|
||||
int n_len,
|
||||
int n_mel);
|
||||
|
||||
// Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context.
|
||||
// Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
// Returns 0 on success
|
||||
PARAKEET_API int parakeet_encode(
|
||||
struct parakeet_context * ctx,
|
||||
int offset,
|
||||
int n_threads);
|
||||
|
||||
PARAKEET_API int parakeet_encode_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
int offset,
|
||||
int n_threads);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||
// TODO: not sure if correct
|
||||
PARAKEET_API int parakeet_tokenize(
|
||||
struct parakeet_context * ctx,
|
||||
const char * text,
|
||||
parakeet_token * tokens,
|
||||
int n_max_tokens);
|
||||
|
||||
// Return the number of tokens in the provided text
|
||||
// Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0)
|
||||
int parakeet_token_count(struct parakeet_context * ctx, const char * text);
|
||||
|
||||
PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length
|
||||
PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length
|
||||
PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx);
|
||||
|
||||
PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx);
|
||||
|
||||
// Token logits obtained from the last call to parakeet_full/parakeet_chunk
|
||||
// The logits for the last token are stored in the last row
|
||||
// Rows: n_tokens
|
||||
// Cols: n_vocab
|
||||
PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx);
|
||||
PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token);
|
||||
|
||||
PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len);
|
||||
|
||||
// Special tokens
|
||||
PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx);
|
||||
PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx);
|
||||
PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx);
|
||||
|
||||
// Performance information from the default state.
|
||||
struct parakeet_timings {
|
||||
float sample_ms;
|
||||
float encode_ms;
|
||||
float decode_ms;
|
||||
};
|
||||
PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx);
|
||||
PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx);
|
||||
PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx);
|
||||
|
||||
// Print system information
|
||||
PARAKEET_API const char * parakeet_print_system_info(void);
|
||||
|
||||
// Available sampling strategies
|
||||
enum parakeet_sampling_strategy {
|
||||
PARAKEET_SAMPLING_GREEDY,
|
||||
};
|
||||
|
||||
// Token callback.
|
||||
// Called for each new predicted token.
|
||||
// Use the parakeet_full_...() functions to obtain the text segments
|
||||
typedef void (*parakeet_new_token_callback)(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
const parakeet_token_data * token_data,
|
||||
void * user_data);
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the parakeet_full_...() functions to obtain the text segments
|
||||
typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data);
|
||||
|
||||
// Progress callback
|
||||
typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data);
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data);
|
||||
|
||||
// Parameters for the parakeet_full() function
|
||||
// If you change the order or add new parameters, make sure to update the default values in parakeet.cpp:
|
||||
// parakeet_full_default_params()
|
||||
struct parakeet_full_params {
|
||||
enum parakeet_sampling_strategy strategy;
|
||||
|
||||
int n_threads;
|
||||
int offset_ms; // start offset in ms
|
||||
int duration_ms; // audio duration to process in ms
|
||||
|
||||
bool no_context; // do not use past transcription (if any) as context
|
||||
|
||||
int audio_ctx; // overwrite the audio context size (0 = use default)
|
||||
|
||||
// called for every newly generated text segment
|
||||
parakeet_new_segment_callback new_segment_callback;
|
||||
void * new_segment_callback_user_data;
|
||||
|
||||
// called for every newly generated token
|
||||
parakeet_new_token_callback new_token_callback;
|
||||
void * new_token_callback_user_data;
|
||||
|
||||
// called on each progress update
|
||||
parakeet_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
|
||||
// called each time before the encoder starts
|
||||
parakeet_encoder_begin_callback encoder_begin_callback;
|
||||
void * encoder_begin_callback_user_data;
|
||||
|
||||
// called each time before ggml computation starts
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_user_data;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params()
|
||||
PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void);
|
||||
PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void);
|
||||
|
||||
PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy);
|
||||
PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy);
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Not thread safe for same context
|
||||
PARAKEET_API int parakeet_full(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_full_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
PARAKEET_API int parakeet_full_with_state(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
struct parakeet_full_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
// Process a single chunk of audio data that fits within the model's audio context window.
|
||||
// This is more efficient than parakeet_full() for short audio clips.
|
||||
PARAKEET_API int parakeet_chunk(
|
||||
struct parakeet_context * ctx,
|
||||
struct parakeet_state * state,
|
||||
struct parakeet_full_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
// Number of generated text segments
|
||||
PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx);
|
||||
PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state);
|
||||
|
||||
// Get the start and end time of the specified segment
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
// Get the text of the specified segment
|
||||
PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
// Get number of tokens in the specified segment
|
||||
PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment);
|
||||
PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment);
|
||||
|
||||
// Get the token text of the specified token in the specified segment
|
||||
PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Get the token id of the specified token in the specified segment
|
||||
PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Get token data for the specified token in the specified segment
|
||||
PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Get the probability of the specified token in the specified segment
|
||||
PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token);
|
||||
PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token);
|
||||
|
||||
// Control logging output; default behavior is to print to stderr
|
||||
|
||||
PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,337 @@
|
|||
#!/usr/bin/env python3
|
||||
# Convert Parakeet TDT model from NeMo format to ggml format
|
||||
#
|
||||
# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32]
|
||||
#
|
||||
# The NeMo file is a tar archive containing:
|
||||
# - model_weights.ckpt (PyTorch checkpoint)
|
||||
# - model_config.yaml (model configuration)
|
||||
# - tokenizer files
|
||||
#
|
||||
# This script extracts the NeMo archive, loads the model weights and configuration,
|
||||
# and saves them in ggml format compatible with whisper.cpp.
|
||||
#
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import yaml
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
def hz_to_mel(freq):
|
||||
return 2595.0 * np.log10(1.0 + freq / 700.0)
|
||||
|
||||
def mel_to_hz(mel):
|
||||
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
|
||||
|
||||
def extract_nemo_archive(nemo_path, extract_dir):
|
||||
print(f"Extracting {nemo_path} to {extract_dir}")
|
||||
with tarfile.open(nemo_path, 'r') as tar:
|
||||
tar.extractall(path=extract_dir)
|
||||
print("Extraction complete")
|
||||
|
||||
def load_model_config(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
return config
|
||||
|
||||
def load_tokenizer(extract_dir, config):
|
||||
tokenizer_model_path = None
|
||||
tokenizer_vocab_path = None
|
||||
|
||||
for file in os.listdir(extract_dir):
|
||||
if file.endswith('_tokenizer.model'):
|
||||
tokenizer_model_path = os.path.join(extract_dir, file)
|
||||
elif file.endswith('tokenizer.vocab'):
|
||||
tokenizer_vocab_path = os.path.join(extract_dir, file)
|
||||
|
||||
if not tokenizer_model_path:
|
||||
raise FileNotFoundError("Tokenizer model file not found")
|
||||
|
||||
if not tokenizer_vocab_path:
|
||||
raise FileNotFoundError("Tokenizer vocab file not found")
|
||||
|
||||
tokens = {}
|
||||
with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f:
|
||||
for idx, line in enumerate(f):
|
||||
parts = line.strip().split('\t')
|
||||
if len(parts) >= 1:
|
||||
token = parts[0]
|
||||
tokens[token.encode('utf-8')] = idx
|
||||
|
||||
print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}")
|
||||
|
||||
if len(tokens) != 8192:
|
||||
print(f"WARNING: Expected 8192 tokens, got {len(tokens)}")
|
||||
|
||||
return tokens
|
||||
|
||||
def write_tensor(fout, name, data, use_f16=True, force_f32=False):
|
||||
if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1:
|
||||
data = data.reshape(1, -1, 1, 1)
|
||||
print(f" Reshaped conv bias {name} to {data.shape}")
|
||||
|
||||
n_dims = len(data.shape)
|
||||
|
||||
ftype = 1 if use_f16 and not force_f32 else 0
|
||||
if force_f32:
|
||||
data = data.astype(np.float32)
|
||||
elif use_f16:
|
||||
if n_dims < 2 or 'bias' in name or 'norm' in name or \
|
||||
('pre_encode.conv' in name and n_dims == 4) or \
|
||||
'depthwise_conv.weight' in name:
|
||||
data = data.astype(np.float32)
|
||||
ftype = 0
|
||||
else:
|
||||
data = data.astype(np.float16)
|
||||
else:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)]
|
||||
print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}")
|
||||
name_bytes = name.encode('utf-8')
|
||||
fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
|
||||
for i in range(n_dims):
|
||||
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
||||
fout.write(name_bytes)
|
||||
|
||||
data.tofile(fout)
|
||||
|
||||
def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None):
|
||||
nemo_path = Path(nemo_path)
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create temporary directory for extraction
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
extract_nemo_archive(nemo_path, temp_dir)
|
||||
|
||||
config_path = os.path.join(temp_dir, 'model_config.yaml')
|
||||
config = load_model_config(config_path)
|
||||
|
||||
print("Model configuration:")
|
||||
print(f" Sample rate: {config['sample_rate']}")
|
||||
print(f" Encoder layers: {config['encoder']['n_layers']}")
|
||||
print(f" Encoder d_model: {config['encoder']['d_model']}")
|
||||
print(f" Mel features: {config['preprocessor']['features']}")
|
||||
|
||||
weights_path = os.path.join(temp_dir, 'model_weights.ckpt')
|
||||
print(f"\nLoading model weights from {weights_path}")
|
||||
checkpoint = torch.load(weights_path, map_location='cpu')
|
||||
|
||||
# Extract state dict
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
print(f"Loaded {len(state_dict)} tensors")
|
||||
|
||||
# Load tokenizer
|
||||
print("\nLoading tokenizer...")
|
||||
tokens = load_tokenizer(temp_dir, config)
|
||||
print(f"Loaded {len(tokens)} tokens")
|
||||
|
||||
# Prepare hyperparameters for the Parakeet ggml format.
|
||||
hparams = {
|
||||
'n_audio_ctx': 5000,
|
||||
'n_audio_state': config['encoder']['d_model'],
|
||||
'n_audio_head': config['encoder']['n_heads'],
|
||||
'n_audio_layer': config['encoder']['n_layers'],
|
||||
'n_mels': config['preprocessor']['features'],
|
||||
'n_fft': config['preprocessor']['n_fft'],
|
||||
'subsampling_factor': config['encoder']['subsampling_factor'],
|
||||
'n_subsampling_channels': config['encoder']['subsampling_conv_channels'],
|
||||
'n_conv_kernel': config['encoder']['conv_kernel_size'],
|
||||
|
||||
'n_pred_dim': config['decoder']['prednet']['pred_hidden'],
|
||||
'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'],
|
||||
'n_vocab': config['decoder']['vocab_size'],
|
||||
'n_tdt_durations': config['model_defaults']['num_tdt_durations'],
|
||||
'n_max_tokens': config['decoding']['greedy']['max_symbols'],
|
||||
}
|
||||
|
||||
print("\nGGML hyperparameters:")
|
||||
for key, value in hparams.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
# Create output file
|
||||
if out_name:
|
||||
fname_out = output_dir / out_name
|
||||
else:
|
||||
fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin")
|
||||
print(f"\nWriting to {fname_out}")
|
||||
|
||||
with open(fname_out, 'wb') as fout:
|
||||
# Write magic number
|
||||
fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex
|
||||
|
||||
# Write hyperparameters
|
||||
fout.write(struct.pack("i", hparams['n_vocab']))
|
||||
fout.write(struct.pack("i", hparams['n_audio_ctx']))
|
||||
fout.write(struct.pack("i", hparams['n_audio_state']))
|
||||
fout.write(struct.pack("i", hparams['n_audio_head']))
|
||||
fout.write(struct.pack("i", hparams['n_audio_layer']))
|
||||
fout.write(struct.pack("i", hparams['n_mels']))
|
||||
fout.write(struct.pack("i", 1 if use_f16 else 0))
|
||||
fout.write(struct.pack("i", hparams['n_fft']))
|
||||
fout.write(struct.pack("i", hparams['subsampling_factor']))
|
||||
fout.write(struct.pack("i", hparams['n_subsampling_channels']))
|
||||
fout.write(struct.pack("i", hparams['n_conv_kernel']))
|
||||
fout.write(struct.pack("i", hparams['n_pred_dim']))
|
||||
fout.write(struct.pack("i", hparams['n_pred_layers']))
|
||||
fout.write(struct.pack("i", hparams['n_tdt_durations']))
|
||||
fout.write(struct.pack("i", hparams['n_max_tokens']))
|
||||
|
||||
# Extract mel filterbank from model
|
||||
fb_key = None
|
||||
for key in state_dict.keys():
|
||||
if 'featurizer.fb' in key or 'filterbank' in key.lower():
|
||||
fb_key = key
|
||||
break
|
||||
|
||||
if not fb_key:
|
||||
print("\nERROR: Mel filterbank not found in model!")
|
||||
print("Expected tensor with 'featurizer.fb' or 'filterbank' in name")
|
||||
print("\nAvailable preprocessor tensors:")
|
||||
for key in sorted(state_dict.keys()):
|
||||
if 'preprocessor' in key or 'featurizer' in key:
|
||||
print(f" {key}: {state_dict[key].shape}")
|
||||
raise ValueError("Mel filterbank tensor not found in model")
|
||||
|
||||
print(f"\nUsing model's mel filterbank from: {fb_key}")
|
||||
mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32)
|
||||
print(f" Filterbank shape: {mel_filters.shape}")
|
||||
print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}")
|
||||
print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}")
|
||||
print(f" First row sum: {mel_filters[0].sum():.6f}")
|
||||
|
||||
if len(mel_filters.shape) != 2:
|
||||
raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}")
|
||||
|
||||
n_mels, n_freqs = mel_filters.shape
|
||||
fout.write(struct.pack("i", n_mels)) # n_mel
|
||||
fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins)
|
||||
|
||||
# Write mel filterbank
|
||||
for i in range(n_mels):
|
||||
for j in range(n_freqs):
|
||||
fout.write(struct.pack("f", mel_filters[i, j]))
|
||||
|
||||
# Extract window function from model
|
||||
window_key = None
|
||||
for key in state_dict.keys():
|
||||
if 'featurizer.window' in key or 'preproc' in key and 'window' in key:
|
||||
window_key = key
|
||||
break
|
||||
|
||||
if not window_key:
|
||||
print("\nERROR: Window function not found in model!")
|
||||
print("Expected tensor with 'featurizer.window' in name")
|
||||
raise ValueError("Window function tensor not found in model")
|
||||
|
||||
print(f"\nUsing model's window function from: {window_key}")
|
||||
window = state_dict[window_key].squeeze().numpy().astype(np.float32)
|
||||
print(f" Window shape: {window.shape}")
|
||||
print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}")
|
||||
print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}")
|
||||
print(f" Window sum: {window.sum():.6f}")
|
||||
|
||||
if len(window.shape) != 1:
|
||||
raise ValueError(f"Expected 1D window, got shape {window.shape}")
|
||||
|
||||
n_window = window.shape[0]
|
||||
fout.write(struct.pack("i", n_window))
|
||||
|
||||
# Write window function
|
||||
for i in range(n_window):
|
||||
fout.write(struct.pack("f", window[i]))
|
||||
|
||||
# Write TDT durations
|
||||
tdt_durations = config['model_defaults']['tdt_durations']
|
||||
if len(tdt_durations) != hparams['n_tdt_durations']:
|
||||
raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}")
|
||||
|
||||
for duration in tdt_durations:
|
||||
fout.write(struct.pack("I", duration))
|
||||
|
||||
fout.write(struct.pack("i", len(tokens)))
|
||||
for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]):
|
||||
fout.write(struct.pack("i", len(token_bytes)))
|
||||
fout.write(token_bytes)
|
||||
|
||||
# Pre-collect prediction LSTM input-hidden biases so they can be
|
||||
# folded into the hidden-hidden bias during the main write loop.
|
||||
lstm_prefix = 'decoder.prediction.dec_rnn.lstm'
|
||||
pred_bias_ih = {}
|
||||
for key, t in state_dict.items():
|
||||
if f'{lstm_prefix}.bias_ih_l' in key:
|
||||
layer_idx = int(key.rsplit('bias_ih_l', 1)[1])
|
||||
pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32)
|
||||
|
||||
print("\nConverting model weights...")
|
||||
for name, tensor in state_dict.items():
|
||||
# Skip the filterbank and window - already written in preprocessing section
|
||||
if name == fb_key:
|
||||
continue
|
||||
if name == window_key:
|
||||
continue
|
||||
|
||||
# bias_ih is folded into bias_hh below; skip writing it separately
|
||||
if f'{lstm_prefix}.bias_ih_l' in name:
|
||||
continue
|
||||
|
||||
# Don't squeeze Conv2d weights - they need to preserve all 4 dimensions
|
||||
if 'conv' in name and 'weight' in name and len(tensor.shape) == 4:
|
||||
data = tensor.numpy()
|
||||
else:
|
||||
data = tensor.squeeze().numpy()
|
||||
|
||||
# For prediction LSTM weights/biases:
|
||||
# Fold bias_ih into bias_hh (bias_ih already skipped above).
|
||||
# Reorder gates (input, forget, cell, output) from PyTorch layout
|
||||
# [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs
|
||||
# (i, f, o) are contiguous.
|
||||
if name.startswith(f'{lstm_prefix}.'):
|
||||
if f'{lstm_prefix}.bias_hh_l' in name:
|
||||
layer_idx = int(name.rsplit('bias_hh_l', 1)[1])
|
||||
data = data.astype(np.float32) + pred_bias_ih[layer_idx]
|
||||
name = name.replace('bias_hh_l', 'bias_h_l')
|
||||
h = data.shape[0] // 4
|
||||
data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)
|
||||
|
||||
write_tensor(fout, name, data, use_f16=use_f16)
|
||||
|
||||
print(f"\nConversion complete!")
|
||||
print(f"Output file: {fname_out}")
|
||||
print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert Parakeet TDT model from NeMo format to ggml format'
|
||||
)
|
||||
parser.add_argument('--model', type=str, required=True,
|
||||
help='Path to Parakeet .nemo model file')
|
||||
parser.add_argument('--out-dir', type=str, required=True,
|
||||
help='Directory to write ggml model file')
|
||||
parser.add_argument('--use-f32', action='store_true', default=False,
|
||||
help='Use f32 instead of f16 (default: f16)')
|
||||
parser.add_argument('--out-name', type=str, default=None,
|
||||
help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.model):
|
||||
print(f"Error: {args.model} not found")
|
||||
sys.exit(1)
|
||||
|
||||
use_f16 = not args.use_f32
|
||||
convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name)
|
||||
Binary file not shown.
|
|
@ -0,0 +1,182 @@
|
|||
#!/usr/bin/env python3
|
||||
import struct
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
def write_tensor(fout, name, data):
|
||||
n_dims = len(data.shape)
|
||||
data = data.astype(np.float32)
|
||||
ftype = 0 # GGML_TYPE_F32
|
||||
|
||||
name_bytes = name.encode('utf-8')
|
||||
fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
|
||||
for i in range(n_dims):
|
||||
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
||||
fout.write(name_bytes)
|
||||
data.tofile(fout)
|
||||
|
||||
def generate(output_path):
|
||||
rng = np.random.default_rng(42)
|
||||
|
||||
hparams = {
|
||||
'n_vocab': 10,
|
||||
'n_audio_ctx': 3200,
|
||||
'n_audio_state': 8,
|
||||
'n_audio_head': 2,
|
||||
'n_audio_layer': 1,
|
||||
'n_mels': 16,
|
||||
'ftype': 0,
|
||||
'n_fft': 64,
|
||||
'subsampling_factor': 8,
|
||||
'n_subsampling_channels': 4,
|
||||
'n_conv_kernel': 3,
|
||||
'n_pred_dim': 8,
|
||||
'n_pred_layers': 1,
|
||||
'n_tdt_durations': 2,
|
||||
'n_max_tokens': 5,
|
||||
}
|
||||
|
||||
n_vocab = hparams['n_vocab']
|
||||
n_state = hparams['n_audio_state']
|
||||
n_head = hparams['n_audio_head']
|
||||
n_layer = hparams['n_audio_layer']
|
||||
n_mels = hparams['n_mels']
|
||||
n_fft = hparams['n_fft']
|
||||
n_sub_fac = hparams['subsampling_factor']
|
||||
n_sub_ch = hparams['n_subsampling_channels']
|
||||
n_conv_ker = hparams['n_conv_kernel']
|
||||
dec_dim = hparams['n_pred_dim']
|
||||
n_pred_l = hparams['n_pred_layers']
|
||||
n_tdt = hparams['n_tdt_durations']
|
||||
|
||||
n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch
|
||||
n_head_dim = n_state // n_head
|
||||
n_pred_embed = n_vocab + 1
|
||||
n_lstm_gates = 4 * dec_dim
|
||||
n_joint_out = n_vocab + n_tdt + 1
|
||||
n_freqs = n_fft // 2 + 1
|
||||
|
||||
def f32(*shape):
|
||||
return rng.standard_normal(shape).astype(np.float32)
|
||||
|
||||
with open(output_path, 'wb') as fout:
|
||||
fout.write(struct.pack("I", 0x67676d6c))
|
||||
|
||||
for key in ['n_vocab',
|
||||
'n_audio_ctx',
|
||||
'n_audio_state',
|
||||
'n_audio_head',
|
||||
'n_audio_layer',
|
||||
'n_mels',
|
||||
'ftype',
|
||||
'n_fft',
|
||||
'subsampling_factor',
|
||||
'n_subsampling_channels',
|
||||
'n_conv_kernel',
|
||||
'n_pred_dim',
|
||||
'n_pred_layers',
|
||||
'n_tdt_durations',
|
||||
'n_max_tokens']:
|
||||
fout.write(struct.pack("i", hparams[key]))
|
||||
|
||||
fout.write(struct.pack("i", n_mels))
|
||||
fout.write(struct.pack("i", n_freqs))
|
||||
f32(n_mels, n_freqs).tofile(fout)
|
||||
|
||||
fout.write(struct.pack("i", n_fft))
|
||||
f32(n_fft).tofile(fout)
|
||||
|
||||
for d in range(n_tdt):
|
||||
fout.write(struct.pack("I", d))
|
||||
|
||||
tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)]
|
||||
assert len(tokens) == n_vocab
|
||||
fout.write(struct.pack("i", n_vocab))
|
||||
for tok in tokens:
|
||||
tok_bytes = tok.encode('utf-8')
|
||||
fout.write(struct.pack("i", len(tok_bytes)))
|
||||
fout.write(tok_bytes)
|
||||
|
||||
write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc))
|
||||
write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state))
|
||||
|
||||
write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3))
|
||||
write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1))
|
||||
|
||||
write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3))
|
||||
write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1))
|
||||
|
||||
write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1))
|
||||
write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1))
|
||||
|
||||
write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3))
|
||||
write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1))
|
||||
|
||||
write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1))
|
||||
write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1))
|
||||
|
||||
for i in range(n_layer):
|
||||
p = f"encoder.layers.{i}"
|
||||
|
||||
write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state))
|
||||
write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state))
|
||||
write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state))
|
||||
write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state))
|
||||
|
||||
write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state))
|
||||
write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state))
|
||||
write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state))
|
||||
write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker))
|
||||
write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state))
|
||||
write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state))
|
||||
write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state))
|
||||
write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state)))
|
||||
num_batches = np.zeros(1, dtype=np.int32)
|
||||
write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches)
|
||||
write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state))
|
||||
|
||||
write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state))
|
||||
write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state))
|
||||
|
||||
write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim))
|
||||
write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim))
|
||||
write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state))
|
||||
write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state))
|
||||
write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state))
|
||||
write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state))
|
||||
write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state))
|
||||
|
||||
write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state))
|
||||
write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state))
|
||||
write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state))
|
||||
write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state))
|
||||
|
||||
write_tensor(fout, f"{p}.norm_out.weight", f32(n_state))
|
||||
write_tensor(fout, f"{p}.norm_out.bias", f32(n_state))
|
||||
|
||||
write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim))
|
||||
|
||||
def reorder_gates(data):
|
||||
h = data.shape[0] // 4
|
||||
return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0)
|
||||
|
||||
for i in range(n_pred_l):
|
||||
base = f"decoder.prediction.dec_rnn.lstm"
|
||||
write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim)))
|
||||
write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim)))
|
||||
write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates)))
|
||||
|
||||
write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim))
|
||||
write_tensor(fout, "joint.pred.bias", f32(dec_dim))
|
||||
write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state))
|
||||
write_tensor(fout, "joint.enc.bias", f32(dec_dim))
|
||||
write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim))
|
||||
write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out))
|
||||
|
||||
size = Path(output_path).stat().st_size
|
||||
print(f"Generated {output_path} ({size / 1024:.1f} KB)")
|
||||
|
||||
if __name__ == '__main__':
|
||||
output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin'
|
||||
generate(output)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
torch
|
||||
numpy
|
||||
pyyaml
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
build_dir=build
|
||||
modelname=ggml-parakeet-tdt-0.6b-v3
|
||||
model=models/${modelname}-f32.bin
|
||||
cmd=parakeet-quantize
|
||||
|
||||
cmake --build ${build_dir} --target $cmd -j 12
|
||||
|
||||
${build_dir}/bin/${cmd} $model models/${modelname}-q8_0.bin q8_0
|
||||
${build_dir}/bin/${cmd} $model models/${modelname}-q4_0.bin q4_0
|
||||
${build_dir}/bin/${cmd} $model models/${modelname}-q4_k.bin q4_k
|
||||
${build_dir}/bin/${cmd} $model models/${modelname}-q2_k.bin q2_k
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
import argparse
|
||||
import os
|
||||
from huggingface_hub import HfApi, create_repo
|
||||
|
||||
USER_NAME = "ggml-org"
|
||||
REPO_ID = f"{USER_NAME}/parakeet-GGUF"
|
||||
|
||||
MODELS = {
|
||||
"f32": {
|
||||
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-f32.bin",
|
||||
"remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin",
|
||||
"description": "Full precision (F32)",
|
||||
},
|
||||
"f16": {
|
||||
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-f16.bin",
|
||||
"remote_name": "ggml-parakeet-tdt-0.6b-v3-f16.bin",
|
||||
"description": "Half precision (F16)",
|
||||
},
|
||||
"q8_0": {
|
||||
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-q8_0.bin",
|
||||
"remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin",
|
||||
"description": "8-bit quantized (Q8_0)",
|
||||
},
|
||||
"q4_0": {
|
||||
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_0.bin",
|
||||
"remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin",
|
||||
"description": "4-bit quantized (Q4_0)",
|
||||
},
|
||||
"q4_k": {
|
||||
"local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_k.bin",
|
||||
"remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_k.bin",
|
||||
"description": "4-bit K-quantized (Q4_k)",
|
||||
},
|
||||
}
|
||||
|
||||
def build_model_card(uploaded_variants):
|
||||
lines = [
|
||||
f"---",
|
||||
f"license: mit",
|
||||
f"base_model: nvidia/parakeet-tdt-0.6b-v3",
|
||||
f"tags:",
|
||||
f"- gguf",
|
||||
f"- asr",
|
||||
f"---",
|
||||
f"",
|
||||
f"# Parakeet TDT 0.6B v3 (GGUF)",
|
||||
f"",
|
||||
f"GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).",
|
||||
f"",
|
||||
f"## Available files",
|
||||
f"",
|
||||
]
|
||||
|
||||
for key, m in MODELS.items():
|
||||
if key in uploaded_variants:
|
||||
lines.append(f"- `{m['remote_name']}` — {m['description']}")
|
||||
|
||||
lines += [
|
||||
f"",
|
||||
f"## Usage",
|
||||
f"",
|
||||
f"Build parakeet-cli:",
|
||||
f"```console",
|
||||
f"git clone https://github.com/ggml-org/whisper.cpp.git",
|
||||
f"cd whisper.cpp",
|
||||
f"cmake -B build -S .",
|
||||
f"cmake --build build --target parakeet-cli -j $(nproc)",
|
||||
f"```",
|
||||
f"",
|
||||
f"Download a model (e.g. Q8_0):",
|
||||
f"```console",
|
||||
f"hf download {REPO_ID} {MODELS['q8_0']['remote_name']} --local-dir models",
|
||||
f"```",
|
||||
f"",
|
||||
f"Run:",
|
||||
f"```console",
|
||||
f"./build/bin/parakeet-cli -m models/{MODELS['q8_0']['remote_name']} -f samples/jfk.wav",
|
||||
f"```",
|
||||
f"",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def upload_variant(api, key):
|
||||
m = MODELS[key]
|
||||
local_path = m["local_path"]
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
print(f" Skipping {key}: {local_path} not found")
|
||||
return False
|
||||
|
||||
print(f" Uploading {m['remote_name']} ({m['description']})...")
|
||||
api.upload_file(
|
||||
path_or_fileobj=local_path,
|
||||
path_in_repo=m["remote_name"],
|
||||
repo_id=REPO_ID,
|
||||
repo_type="model",
|
||||
commit_message=f"Upload {m['remote_name']}",
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face")
|
||||
parser.add_argument(
|
||||
"variants",
|
||||
nargs="*",
|
||||
default=None,
|
||||
metavar="{" + ",".join(MODELS.keys()) + "}",
|
||||
help="Model variants to upload (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-model-card",
|
||||
action="store_true",
|
||||
help="Skip updating the model card README",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
api = HfApi()
|
||||
create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
|
||||
|
||||
variants = args.variants if args.variants else list(MODELS.keys())
|
||||
|
||||
unknown = [v for v in variants if v not in MODELS]
|
||||
if unknown:
|
||||
parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})")
|
||||
|
||||
uploaded = []
|
||||
for key in variants:
|
||||
if upload_variant(api, key):
|
||||
uploaded.append(key)
|
||||
|
||||
if not uploaded:
|
||||
print("No models were uploaded.")
|
||||
return
|
||||
|
||||
if not args.no_model_card:
|
||||
print("Updating model card...")
|
||||
existing = [k for k in MODELS if k in uploaded or
|
||||
any(f.rfilename == MODELS[k]["remote_name"]
|
||||
for f in api.list_repo_files(REPO_ID, repo_type="model")
|
||||
if hasattr(f, "rfilename"))]
|
||||
card = build_model_card(existing if existing else uploaded)
|
||||
api.upload_file(
|
||||
path_or_fileobj=card.encode(),
|
||||
path_in_repo="README.md",
|
||||
repo_id=REPO_ID,
|
||||
repo_type="model",
|
||||
commit_message="Update README.md",
|
||||
)
|
||||
|
||||
print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -109,23 +109,43 @@ add_library(whisper
|
|||
whisper.cpp
|
||||
)
|
||||
|
||||
add_library(parakeet
|
||||
../include/parakeet.h
|
||||
parakeet-arch.h
|
||||
parakeet.cpp
|
||||
)
|
||||
|
||||
target_include_directories(parakeet PUBLIC . ../include)
|
||||
target_compile_features (parakeet PUBLIC cxx_std_11)
|
||||
target_link_libraries(parakeet PUBLIC ggml Threads::Threads)
|
||||
|
||||
# Set the version numbers
|
||||
set_target_properties(whisper PROPERTIES
|
||||
VERSION ${PROJECT_VERSION}
|
||||
SOVERSION ${SOVERSION}
|
||||
)
|
||||
|
||||
set_target_properties(parakeet PROPERTIES
|
||||
VERSION ${PROJECT_VERSION}
|
||||
SOVERSION ${SOVERSION}
|
||||
)
|
||||
|
||||
target_include_directories(whisper PUBLIC . ../include)
|
||||
target_compile_features (whisper PUBLIC cxx_std_11) # don't bump
|
||||
|
||||
if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN")
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN)
|
||||
set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN)
|
||||
endif()
|
||||
|
||||
if (WHISPER_EXTRA_FLAGS)
|
||||
target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS})
|
||||
endif()
|
||||
|
||||
if (PARAKEET_EXTRA_FLAGS)
|
||||
target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS})
|
||||
endif()
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
target_link_libraries(whisper PUBLIC ggml Threads::Threads)
|
||||
|
||||
|
|
@ -144,4 +164,7 @@ endif()
|
|||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD)
|
||||
|
||||
set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,188 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
enum parakeet_tensor {
|
||||
// Encoder pre_encode
|
||||
PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_PRE_OUT_BIAS,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS,
|
||||
|
||||
// Encoder layers (per-layer)
|
||||
PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_FF1_BIAS,
|
||||
PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_CONV_BIAS,
|
||||
PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_CONV_BN_BIAS,
|
||||
PARAKEET_TENSOR_ENC_CONV_BN_MEAN,
|
||||
PARAKEET_TENSOR_ENC_CONV_BN_VAR,
|
||||
PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES,
|
||||
PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS,
|
||||
PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U,
|
||||
PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V,
|
||||
PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_FF2_BIAS,
|
||||
PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT,
|
||||
PARAKEET_TENSOR_ENC_NORM_OUT_BIAS,
|
||||
|
||||
// Prediction network
|
||||
PARAKEET_TENSOR_PRED_EMBED_WEIGHT,
|
||||
PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH,
|
||||
PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH,
|
||||
PARAKEET_TENSOR_PRED_LSTM_BIAS_H,
|
||||
|
||||
// Joint network
|
||||
PARAKEET_TENSOR_JOINT_PRED_WEIGHT,
|
||||
PARAKEET_TENSOR_JOINT_PRED_BIAS,
|
||||
PARAKEET_TENSOR_JOINT_ENC_WEIGHT,
|
||||
PARAKEET_TENSOR_JOINT_ENC_BIAS,
|
||||
PARAKEET_TENSOR_JOINT_NET_WEIGHT,
|
||||
PARAKEET_TENSOR_JOINT_NET_BIAS,
|
||||
};
|
||||
|
||||
static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = {
|
||||
// Encoder pre_encode
|
||||
{PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"},
|
||||
|
||||
// Encoder layers (use %d for layer number)
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"},
|
||||
{PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"},
|
||||
{PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"},
|
||||
{PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"},
|
||||
{PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"},
|
||||
{PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"},
|
||||
{PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"},
|
||||
|
||||
// Prediction network
|
||||
{PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"},
|
||||
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"},
|
||||
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"},
|
||||
{PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"},
|
||||
|
||||
// Joint network
|
||||
{PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"},
|
||||
{PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"},
|
||||
{PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"},
|
||||
{PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"},
|
||||
{PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"},
|
||||
{PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"},
|
||||
};
|
||||
|
||||
static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = {
|
||||
// Encoder pre_encode
|
||||
{PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD},
|
||||
|
||||
// Encoder layers
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL},
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL},
|
||||
{PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV},
|
||||
{PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE},
|
||||
{PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL},
|
||||
{PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL},
|
||||
{PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL},
|
||||
{PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL},
|
||||
{PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD},
|
||||
|
||||
// Prediction network
|
||||
{PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS},
|
||||
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD},
|
||||
|
||||
// Joint network
|
||||
{PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD},
|
||||
{PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT},
|
||||
{PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD},
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -118,3 +118,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE
|
|||
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
|
||||
add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST})
|
||||
set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en")
|
||||
|
||||
# Parakeet model loading test
|
||||
set(PARAKEET_TEST test-parakeet)
|
||||
add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp)
|
||||
target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples)
|
||||
target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common)
|
||||
target_compile_definitions(${PARAKEET_TEST} PRIVATE
|
||||
PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin"
|
||||
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav")
|
||||
add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST})
|
||||
set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh")
|
||||
|
||||
# The following parakeet test require a real ggml-parakeet-tdt model to have
|
||||
# been converted or downloaded:
|
||||
# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models
|
||||
#
|
||||
# And also required more audio samples that are shipped by default. These can
|
||||
# downloaded by running:
|
||||
# $ make samples
|
||||
function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH)
|
||||
set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0")
|
||||
if (ARGC GREATER 4)
|
||||
set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}")
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCE})
|
||||
target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE parakeet common)
|
||||
target_compile_definitions(${TEST_TARGET} PRIVATE
|
||||
PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin"
|
||||
SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}"
|
||||
EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}"
|
||||
TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD})
|
||||
|
||||
add_custom_target(run-${TEST_TARGET}
|
||||
COMMAND $<TARGET_FILE:${TEST_TARGET}>
|
||||
DEPENDS ${TEST_TARGET}
|
||||
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
|
||||
endfunction()
|
||||
|
||||
add_parakeet_transcription_test(
|
||||
test-parakeet-full-jfk
|
||||
test-parakeet-full.cpp
|
||||
samples/jfk.wav
|
||||
tests/parakeet-expected-jfk-output.txt)
|
||||
|
||||
add_parakeet_transcription_test(
|
||||
test-parakeet-full-gb1
|
||||
test-parakeet-full.cpp
|
||||
samples/gb1.wav
|
||||
tests/parakeet-expected-gb1-output.txt)
|
||||
|
||||
add_parakeet_transcription_test(
|
||||
test-parakeet-full-diffusion
|
||||
test-parakeet-full.cpp
|
||||
samples/diffusion2023-07-03.flac
|
||||
tests/parakeet-expected-diffusion-output.txt
|
||||
0.95)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
__pycache__
|
||||
*.tar.gz
|
||||
*.txt
|
||||
eval.conf
|
||||
venv
|
||||
LibriSpeech
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz
|
||||
|
||||
all: eval
|
||||
|
||||
eval:
|
||||
$(MAKE) -f eval.mk
|
||||
|
||||
clean:
|
||||
$(MAKE) -f eval.mk clean
|
||||
|
||||
get-audio:
|
||||
wget -c $(TAR_URL)
|
||||
tar -xf test-clean.tar.gz
|
||||
|
||||
.PHONY: all eval clean setup-venv clean-venv get-audio
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# parakeet.cpp/tests/librispeech
|
||||
|
||||
[LibriSpeech](https://www.openslr.org/12) is a standard dataset for
|
||||
training and evaluating automatic speech recognition systems.
|
||||
|
||||
This directory contains a set of tools to evaluate the recognition
|
||||
performance of parakeet.cpp on LibriSpeech corpus.
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. (Pre-requirement) Compile `parakeet-cli` and prepare the Parakeet
|
||||
model in `ggml` format.
|
||||
|
||||
```
|
||||
$ # Execute the commands below in the project root dir.
|
||||
$ cmake -B build
|
||||
$ cmake --build build --config Release
|
||||
```
|
||||
|
||||
2. Download the audio files from LibriSpeech project.
|
||||
|
||||
```
|
||||
$ make get-audio
|
||||
```
|
||||
|
||||
3. Set up the environment to compute WER score.
|
||||
|
||||
```
|
||||
$ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
For example, if you use `virtualenv`, you can set up it as follows:
|
||||
|
||||
```
|
||||
$ python3 -m venv venv
|
||||
$ . venv/bin/activate
|
||||
$ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
4. Run the benchmark test.
|
||||
|
||||
```
|
||||
$ make
|
||||
```
|
||||
|
||||
## How-to guides
|
||||
|
||||
### How to change the inference parameters
|
||||
|
||||
Create `eval.conf` and override variables.
|
||||
|
||||
```
|
||||
PARAKEET_MODEL = parakeet-tdt-0.6b-v3
|
||||
PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt
|
||||
```
|
||||
|
||||
Check out `eval.mk` for more details.
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
PYTHON = python
|
||||
|
||||
PARAKEET_PREFIX = ../../
|
||||
PARAKEET_MODEL = parakeet-tdt-0.6b-v3
|
||||
|
||||
PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli
|
||||
PARAKEET_FLAGS = --no-prints --output-txt
|
||||
|
||||
# You can create eval.conf to override the PARAKEET_* variables
|
||||
# defined above.
|
||||
-include eval.conf
|
||||
|
||||
# This follows the file structure of the LibriSpeech project.
|
||||
AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac))
|
||||
TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS))
|
||||
|
||||
# We output the evaluation result to this file.
|
||||
DONE = $(PARAKEET_MODEL).txt
|
||||
|
||||
all: $(DONE)
|
||||
|
||||
$(DONE): $(TRANS_TXTS)
|
||||
$(PYTHON) eval.py > $@.tmp
|
||||
mv $@.tmp $@
|
||||
|
||||
# Note: This task writes to a temporary file first to
|
||||
# create the target file atomically.
|
||||
%.flac.txt: %.flac
|
||||
$(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp
|
||||
mv $^.tmp.txt $^.txt
|
||||
|
||||
archive:
|
||||
tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE)
|
||||
|
||||
clean:
|
||||
@rm -f $(TRANS_TXTS)
|
||||
@rm -f $(DONE)
|
||||
|
||||
.PHONY: all clean
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import glob
|
||||
import jiwer
|
||||
from normalizers import EnglishTextNormalizer
|
||||
|
||||
def get_reference():
|
||||
ref = {}
|
||||
for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'):
|
||||
with open(path) as fp:
|
||||
for line in fp:
|
||||
code, text = line.strip().split(" ", maxsplit=1)
|
||||
ref [code] = text
|
||||
return ref
|
||||
|
||||
def get_hypothesis():
|
||||
hyp = {}
|
||||
for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'):
|
||||
with open(path) as fp:
|
||||
text = fp.read().strip()
|
||||
code = os.path.basename(path).replace('.flac.txt', '')
|
||||
hyp[code] = text
|
||||
return hyp
|
||||
|
||||
def get_codes():
|
||||
codes = []
|
||||
for path in glob.glob('LibriSpeech/*/*/*/*.flac'):
|
||||
codes.append(os.path.basename(path).replace('.flac', ''))
|
||||
return sorted(codes)
|
||||
|
||||
def main():
|
||||
normalizer = EnglishTextNormalizer()
|
||||
|
||||
ref_orig = get_reference()
|
||||
hyp_orig = get_hypothesis()
|
||||
|
||||
ref_clean = []
|
||||
hyp_clean = []
|
||||
|
||||
for code in get_codes():
|
||||
ref_clean.append(normalizer(ref_orig[code]))
|
||||
hyp_clean.append(normalizer(hyp_orig[code]))
|
||||
|
||||
wer = jiwer.wer(ref_clean, hyp_clean)
|
||||
print(f"WER: {wer * 100:.2f}%")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
Code in this directory is adapted from OpenAI Whisper project
|
||||
(https://github.com/openai/whisper) and carries the following
|
||||
copyright and license.
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 OpenAI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .basic import BasicTextNormalizer as BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer as EnglishTextNormalizer
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import re
|
||||
import unicodedata
|
||||
|
||||
import regex
|
||||
|
||||
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||
ADDITIONAL_DIACRITICS = {
|
||||
"œ": "oe",
|
||||
"Œ": "OE",
|
||||
"ø": "o",
|
||||
"Ø": "O",
|
||||
"æ": "ae",
|
||||
"Æ": "AE",
|
||||
"ß": "ss",
|
||||
"ẞ": "SS",
|
||||
"đ": "d",
|
||||
"Đ": "D",
|
||||
"ð": "d",
|
||||
"Ð": "D",
|
||||
"þ": "th",
|
||||
"Þ": "th",
|
||||
"ł": "l",
|
||||
"Ł": "L",
|
||||
}
|
||||
|
||||
|
||||
def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
"""
|
||||
Replace any other markers, symbols, and punctuations with a space,
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
(
|
||||
c
|
||||
if c in keep
|
||||
else (
|
||||
ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else (
|
||||
""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||
)
|
||||
)
|
||||
)
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
def remove_symbols(s: str):
|
||||
"""
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c
|
||||
for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = (
|
||||
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
)
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = self.clean(s).lower()
|
||||
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(
|
||||
r"\s+", " ", s
|
||||
) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,550 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th")
|
||||
for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {
|
||||
**self.multipliers_plural,
|
||||
**self.multipliers_ordinal,
|
||||
}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values())
|
||||
+ list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||
prefix: Optional[str] = None
|
||||
value: Optional[Union[str, int]] = None
|
||||
skip = False
|
||||
|
||||
def to_fraction(s: str):
|
||||
try:
|
||||
return Fraction(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def output(result: Union[str, int]):
|
||||
nonlocal prefix, value
|
||||
result = str(result)
|
||||
if prefix is not None:
|
||||
result = prefix + result
|
||||
value = None
|
||||
prefix = None
|
||||
return result
|
||||
|
||||
if len(words) == 0:
|
||||
return
|
||||
|
||||
for prev, current, next in windowed([None] + words + [None], 3):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||
has_prefix = current[0] in self.prefixes
|
||||
current_without_prefix = current[1:] if has_prefix else current
|
||||
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||
# arabic numbers (potentially with signs and fractions)
|
||||
f = to_fraction(current_without_prefix)
|
||||
assert f is not None
|
||||
if value is not None:
|
||||
if isinstance(value, str) and value.endswith("."):
|
||||
# concatenate decimals / ip address components
|
||||
value = str(value) + str(current)
|
||||
continue
|
||||
else:
|
||||
yield output(value)
|
||||
|
||||
prefix = current[0] if has_prefix else prefix
|
||||
if f.denominator == 1:
|
||||
value = f.numerator # store integers as int
|
||||
else:
|
||||
value = current_without_prefix
|
||||
elif current not in self.words:
|
||||
# non-numeric words
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current in self.zeros:
|
||||
value = str(value or "") + "0"
|
||||
elif current in self.ones:
|
||||
ones = self.ones[current]
|
||||
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if (
|
||||
prev in self.tens and ones < 10
|
||||
): # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif current in self.ones_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
ones, suffix = self.ones_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(ones) + suffix)
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10:
|
||||
assert value[-1] == "0"
|
||||
yield output(value[:-1] + str(ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
value = None
|
||||
elif current in self.tens:
|
||||
tens = self.tens[current]
|
||||
if value is None:
|
||||
value = tens
|
||||
elif isinstance(value, str):
|
||||
value = str(value) + str(tens)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
value += tens
|
||||
else:
|
||||
value = str(value) + str(tens)
|
||||
elif current in self.tens_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
tens, suffix = self.tens_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(tens) + suffix)
|
||||
elif isinstance(value, str):
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + tens) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
elif current in self.multipliers:
|
||||
multiplier = self.multipliers[current]
|
||||
if value is None:
|
||||
value = multiplier
|
||||
elif isinstance(value, str) or value == 0:
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
value = p.numerator
|
||||
else:
|
||||
yield output(value)
|
||||
value = multiplier
|
||||
else:
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
elif current in self.multipliers_suffixed:
|
||||
multiplier, suffix = self.multipliers_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(multiplier) + suffix)
|
||||
elif isinstance(value, str):
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
yield output(str(p.numerator) + suffix)
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(str(multiplier) + suffix)
|
||||
else: # int
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
yield output(str(value) + suffix)
|
||||
value = None
|
||||
elif current in self.preceding_prefixers:
|
||||
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
if next in self.words or next_is_numeric:
|
||||
prefix = self.preceding_prefixers[current]
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.following_prefixers:
|
||||
# apply prefix (dollars, cents, etc.) only after a number
|
||||
if value is not None:
|
||||
prefix = self.following_prefixers[current]
|
||||
yield output(value)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.suffixers:
|
||||
# apply suffix symbols (percent -> '%')
|
||||
if value is not None:
|
||||
suffix = self.suffixers[current]
|
||||
if isinstance(suffix, dict):
|
||||
if next in suffix:
|
||||
yield output(str(value) + suffix[next])
|
||||
skip = True
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
else:
|
||||
yield output(str(value) + suffix)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.specials:
|
||||
if next not in self.words and not next_is_numeric:
|
||||
# apply special handling only if the next word can be numeric
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "and":
|
||||
# ignore "and" after hundreds, thousands, etc.
|
||||
if prev not in self.multipliers:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "double" or current == "triple":
|
||||
if next in self.ones or next in self.zeros:
|
||||
repeats = 2 if current == "double" else 3
|
||||
ones = self.ones.get(next, 0)
|
||||
value = str(value or "") + str(ones) * repeats
|
||||
skip = True
|
||||
else:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "point":
|
||||
if next in self.decimals or next_is_numeric:
|
||||
value = str(value or "") + "."
|
||||
else:
|
||||
# should all have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
else:
|
||||
# all should have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
|
||||
def combine_cents(m: Match):
|
||||
try:
|
||||
currency = m.group(1)
|
||||
integer = m.group(2)
|
||||
cents = int(m.group(3))
|
||||
return f"{currency}{integer}.{cents:02d}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
def extract_cents(m: Match):
|
||||
try:
|
||||
return f"¢{int(m.group(1))}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||
|
||||
# write "one(s)" instead of "1(s)", just for the readability
|
||||
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||
|
||||
return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
|
||||
"""
|
||||
Applies British-American spelling mappings as listed in [1].
|
||||
|
||||
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||
self.mapping = json.load(open(mapping_path))
|
||||
|
||||
def __call__(self, s: str):
|
||||
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
|
||||
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
||||
|
||||
return s
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1 @@
|
|||
My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America.
|
||||
|
|
@ -0,0 +1 @@
|
|||
And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD
|
||||
#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0
|
||||
#endif
|
||||
|
||||
static std::string read_expected_transcription(const char * path) {
|
||||
std::ifstream fin(path);
|
||||
assert(fin.is_open());
|
||||
|
||||
std::string text(
|
||||
(std::istreambuf_iterator<char>(fin)),
|
||||
std::istreambuf_iterator<char>());
|
||||
|
||||
while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) {
|
||||
text.pop_back();
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
static std::vector<std::string> transcription_words(const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
std::string word;
|
||||
|
||||
for (unsigned char ch : text) {
|
||||
if (std::isalnum(ch)) {
|
||||
word.push_back((char) std::tolower(ch));
|
||||
} else if (!word.empty()) {
|
||||
words.push_back(word);
|
||||
word.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (!word.empty()) {
|
||||
words.push_back(word);
|
||||
}
|
||||
|
||||
return words;
|
||||
}
|
||||
|
||||
static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) {
|
||||
const std::vector<std::string> expected_words = transcription_words(expected);
|
||||
const std::vector<std::string> actual_words = transcription_words(actual);
|
||||
|
||||
if (expected_words.empty() && actual_words.empty()) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
if (expected_words.empty() || actual_words.empty()) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
std::vector<int> prev(actual_words.size() + 1, 0);
|
||||
std::vector<int> cur (actual_words.size() + 1, 0);
|
||||
|
||||
for (size_t i = 0; i < expected_words.size(); ++i) {
|
||||
std::fill(cur.begin(), cur.end(), 0);
|
||||
|
||||
for (size_t j = 0; j < actual_words.size(); ++j) {
|
||||
if (expected_words[i] == actual_words[j]) {
|
||||
cur[j + 1] = prev[j] + 1;
|
||||
} else {
|
||||
cur[j + 1] = std::max(prev[j + 1], cur[j]);
|
||||
}
|
||||
}
|
||||
|
||||
prev.swap(cur);
|
||||
}
|
||||
|
||||
const int lcs = prev[actual_words.size()];
|
||||
return (2.0 * lcs) / (expected_words.size() + actual_words.size());
|
||||
}
|
||||
|
||||
static bool verify_transcription(const std::string & expected, const std::string & actual) {
|
||||
const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD;
|
||||
|
||||
if (threshold >= 1.0) {
|
||||
if (actual == expected) {
|
||||
return true;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
fprintf(stderr, "[Failed] Transcript mismatched\n");
|
||||
fprintf(stderr, "expected:\n%s\n\n", expected.c_str());
|
||||
fprintf(stderr, "actual:\n%s\n", actual.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
const double similarity = transcription_lcs_similarity(expected, actual);
|
||||
printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold);
|
||||
|
||||
if (similarity >= threshold) {
|
||||
return true;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold);
|
||||
fprintf(stderr, "Expected:\n%s\n\n", expected.c_str());
|
||||
fprintf(stderr, "Actual:\n%s\n", actual.c_str());
|
||||
return false;
|
||||
}
|
||||
|
|
@ -21,13 +21,21 @@ cd `dirname $0`
|
|||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" )
|
||||
|
||||
# Parakeet model variants
|
||||
parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
printf "\n"
|
||||
printf " Available models:"
|
||||
printf " Available whisper models:"
|
||||
for model in "${models[@]}"; do
|
||||
printf " $model"
|
||||
done
|
||||
printf "\n"
|
||||
printf " Available parakeet models:"
|
||||
for model in "${parakeet_models[@]}"; do
|
||||
printf " parakeet-$model"
|
||||
done
|
||||
printf "\n\n"
|
||||
}
|
||||
|
||||
|
|
@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then
|
|||
fi
|
||||
|
||||
model=$1
|
||||
main="../build/bin/whisper-cli"
|
||||
|
||||
threads=""
|
||||
if [ $# -eq 2 ]; then
|
||||
threads="-t $2"
|
||||
fi
|
||||
|
||||
if [ ! -f ../models/ggml-$model.bin ]; then
|
||||
printf "Model $model not found. Aborting\n"
|
||||
# Detect parakeet model (prefix "parakeet-" or a bare variant like "f32")
|
||||
is_parakeet=0
|
||||
parakeet_variant=""
|
||||
if [[ $model == parakeet-* ]]; then
|
||||
is_parakeet=1
|
||||
parakeet_variant="${model#parakeet-}"
|
||||
fi
|
||||
for v in "${parakeet_models[@]}"; do
|
||||
if [[ $model == "$v" ]]; then
|
||||
is_parakeet=1
|
||||
parakeet_variant="$v"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $is_parakeet -eq 1 ]; then
|
||||
main="../build/bin/parakeet-cli"
|
||||
model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin"
|
||||
else
|
||||
main="../build/bin/whisper-cli"
|
||||
model_path="../models/ggml-${model}.bin"
|
||||
fi
|
||||
|
||||
if [ ! -f $model_path ]; then
|
||||
printf "Model $model not found ($model_path). Aborting\n"
|
||||
list_models
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -110,7 +140,11 @@ function run_lang() {
|
|||
fi
|
||||
fi
|
||||
|
||||
$main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null
|
||||
if [ $is_parakeet -eq 1 ]; then
|
||||
$main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null
|
||||
else
|
||||
$main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null
|
||||
fi
|
||||
|
||||
git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt
|
||||
|
||||
|
|
@ -120,7 +154,7 @@ function run_lang() {
|
|||
|
||||
run_lang "en" "${urls_en[@]}"
|
||||
|
||||
if [[ $model != *.en* ]]; then
|
||||
if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then
|
||||
run_lang "es" "${urls_es[@]}"
|
||||
run_lang "it" "${urls_it[@]}"
|
||||
run_lang "pt" "${urls_pt[@]}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
#include "parakeet.h"
|
||||
#include "common-whisper.h"
|
||||
#include "parakeet-verification.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
#include <cassert>
|
||||
|
||||
struct test_state {
|
||||
bool is_first = true;
|
||||
std::string transcript;
|
||||
};
|
||||
|
||||
void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) {
|
||||
bool * called = static_cast<bool *>(user_data);
|
||||
*called = true;
|
||||
}
|
||||
|
||||
bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) {
|
||||
bool * called = static_cast<bool *>(user_data);
|
||||
*called = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool abort_callback(void * user_data) {
|
||||
bool * called = static_cast<bool *>(user_data);
|
||||
*called = true;
|
||||
return false; // just continue without aborting.
|
||||
}
|
||||
|
||||
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
|
||||
test_state * tstate = static_cast<test_state *>(user_data);
|
||||
|
||||
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
|
||||
char text_buf[256];
|
||||
parakeet_token_to_text(token_str, tstate->is_first, text_buf, sizeof(text_buf));
|
||||
|
||||
printf("%s", text_buf);
|
||||
fflush(stdout);
|
||||
|
||||
tstate->transcript += text_buf;
|
||||
tstate->is_first = false;
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::string model_path = PARAKEET_MODEL_PATH;
|
||||
std::string sample_path = SAMPLE_PATH;
|
||||
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
|
||||
assert(pcmf32.size() > 0);
|
||||
assert(pcmf32s.size() == 0); // no stereo vector
|
||||
|
||||
printf("Loading Parakeet model from: %s\n", model_path.c_str());
|
||||
|
||||
struct parakeet_context_params ctx_params = parakeet_context_default_params();
|
||||
|
||||
struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params);
|
||||
if (pctx == nullptr) {
|
||||
fprintf(stderr, "Failed to load Parakeet model\n");
|
||||
return 1;
|
||||
}
|
||||
printf("Successfully loaded Parakeet model\n");
|
||||
|
||||
struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
test_state tstate;
|
||||
params.new_token_callback = token_callback;
|
||||
params.new_token_callback_user_data = &tstate;
|
||||
bool progress_callback_called = false;
|
||||
params.progress_callback = progress_callback;
|
||||
params.progress_callback_user_data = &progress_callback_called;
|
||||
bool encoder_begin_callback_called = false;
|
||||
params.encoder_begin_callback = encoder_begin_callback;
|
||||
params.encoder_begin_callback_user_data = &encoder_begin_callback_called;
|
||||
bool abort_callback_called = false;
|
||||
params.abort_callback = abort_callback;
|
||||
params.abort_callback_user_data = &abort_callback_called;
|
||||
|
||||
int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size());
|
||||
assert(ret == 0);
|
||||
assert(progress_callback_called);
|
||||
assert(encoder_begin_callback_called);
|
||||
assert(abort_callback_called);
|
||||
|
||||
const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH);
|
||||
const bool transcript_matches = verify_transcription(expected, tstate.transcript);
|
||||
|
||||
parakeet_free(pctx);
|
||||
|
||||
if (!transcript_matches) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("\nTest passed: parakeet_full succeeded!\n");
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
#include "parakeet.h"
|
||||
#include "common-whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
#include <cassert>
|
||||
|
||||
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
|
||||
static bool is_first = true;
|
||||
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
|
||||
char text_buf[256];
|
||||
parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf));
|
||||
|
||||
int32_t time_ms = token_data->frame_index * 10;
|
||||
|
||||
printf("%s", text_buf);
|
||||
fflush(stdout);
|
||||
|
||||
is_first = false;
|
||||
}
|
||||
|
||||
void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) {
|
||||
const int n_segments = parakeet_full_n_segments_from_state(state);
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
printf("\nSegment Callback: %d new segment(s)\n", n_new);
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
const char * text = parakeet_full_get_segment_text_from_state(state, i);
|
||||
const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i);
|
||||
const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i);
|
||||
|
||||
printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
|
||||
printf("Tokens:\n");
|
||||
|
||||
const int n_tokens = parakeet_full_n_tokens_from_state(state, i);
|
||||
for (int j = 0; j < n_tokens; j++) {
|
||||
parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j);
|
||||
const char * token_str = parakeet_token_to_str(ctx, token_data.id);
|
||||
|
||||
printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n",
|
||||
j,
|
||||
token_data.id,
|
||||
token_data.frame_index,
|
||||
token_data.duration_idx,
|
||||
token_data.duration_value,
|
||||
token_data.p,
|
||||
token_data.plog,
|
||||
(long long)token_data.t0,
|
||||
(long long)token_data.t1,
|
||||
token_data.is_word_start,
|
||||
token_str);
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::string model_path = PARAKEET_MODEL_PATH;
|
||||
std::string sample_path = SAMPLE_PATH;
|
||||
|
||||
// Load the sample audio file
|
||||
std::vector<float> pcmf32;
|
||||
std::vector<std::vector<float>> pcmf32s;
|
||||
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
|
||||
assert(pcmf32.size() > 0);
|
||||
assert(pcmf32s.size() == 0);
|
||||
|
||||
printf("Loading Parakeet model from: %s\n", model_path.c_str());
|
||||
|
||||
struct parakeet_context_params ctx_params = parakeet_context_default_params();
|
||||
|
||||
struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params);
|
||||
if (pctx == nullptr) {
|
||||
fprintf(stderr, "Failed to load Parakeet model\n");
|
||||
return 1;
|
||||
}
|
||||
printf("Successfully loaded Parakeet model\n");
|
||||
|
||||
struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
|
||||
params.new_token_callback = token_callback;
|
||||
params.new_token_callback_user_data = nullptr;
|
||||
params.new_segment_callback = segment_callback;
|
||||
params.new_segment_callback_user_data = nullptr;
|
||||
parakeet_state * state = parakeet_init_state(pctx);
|
||||
|
||||
int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size());
|
||||
assert(ret == 0);
|
||||
|
||||
parakeet_free_state(state);
|
||||
parakeet_free(pctx);
|
||||
|
||||
printf("\nTest passed: Parakeet model loaded and freed successfully\n");
|
||||
return 0;
|
||||
}
|
||||
Loading…
Reference in New Issue