whisper.cpp/examples/parakeet-cli/parakeet-cli.cpp

244 lines
9.8 KiB
C++

#include "parakeet.h"
#include "common-whisper.h"
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#include <fstream>
// command-line parameters
struct parakeet_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
bool use_gpu = true;
int32_t gpu_device = 0;
bool print_segments = false;
bool output_txt = false;
bool no_prints = false;
std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin";
std::string output_file = "";
std::vector<std::string> fname_inp = {};
};
static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params);
static char * requires_value_error(const std::string & arg) {
fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
exit(1);
}
static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) {
if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) {
params.gpu_device = std::stoi(env_device);
}
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}
if (arg[0] != '-') {
params.fname_inp.push_back(arg);
continue;
}
if (arg == "-h" || arg == "--help") {
parakeet_print_usage(argc, argv, params);
exit(0);
}
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); }
else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
parakeet_print_usage(argc, argv, params);
exit(1);
}
}
return true;
}
static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]);
fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n");
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", "");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device);
fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, "\n");
}
void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) {
bool * is_first = (bool *) user_data;
const char * token_str = parakeet_token_to_str(ctx, token_data->id);
char text_buf[256];
parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf));
printf("%s", text_buf);
fflush(stdout);
*is_first = false;
}
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
int main(int argc, char ** argv) {
ggml_backend_load_all();
parakeet_params params;
if (parakeet_params_parse(argc, argv, params) == false) {
return 1;
}
if (params.no_prints) {
parakeet_log_set(cb_log_disable, NULL);
}
if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
parakeet_print_usage(argc, argv, params);
return 1;
}
struct parakeet_context_params ctx_params = parakeet_context_default_params();
ctx_params.use_gpu = params.use_gpu;
ctx_params.gpu_device = params.gpu_device;
if (!params.no_prints) {
fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str());
}
struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params);
if (pctx == nullptr) {
fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str());
return 1;
}
if (!params.no_prints) {
fprintf(stderr, "Successfully loaded Parakeet model\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info());
}
// Process each input file
for (const auto & fname : params.fname_inp) {
if (!params.no_prints) {
fprintf(stderr, "\nProcessing file: %s\n", fname.c_str());
}
std::vector<float> pcmf32;
std::vector<std::vector<float>> pcmf32s;
if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str());
continue;
}
if (pcmf32.empty()) {
fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str());
continue;
}
bool is_first = true;
struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY);
full_params.n_threads = params.n_threads;
full_params.new_token_callback = token_callback;
full_params.new_token_callback_user_data = &is_first;
const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH);
int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size());
if (ret != 0) {
fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str());
continue;
}
printf("\n");
if (params.output_txt) {
const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt";
std::ofstream fout(fname_out);
if (fout.is_open()) {
const int n_segments = parakeet_full_n_segments(pctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = parakeet_full_get_segment_text(pctx, i);
fout << text << "\n";
}
fout.close();
if (!params.no_prints) {
fprintf(stderr, "Output written to: %s\n", fname_out.c_str());
}
} else {
fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str());
}
}
if (!params.no_prints) {
parakeet_print_timings(pctx);
}
if (params.print_segments) {
const int n_segments = parakeet_full_n_segments(pctx);
fprintf(stderr, "\nSegments (%d):\n", n_segments);
for (int i = 0; i < n_segments; i++) {
const char * text = parakeet_full_get_segment_text(pctx, i);
const int64_t t0 = parakeet_full_get_segment_t0(pctx, i);
const int64_t t1 = parakeet_full_get_segment_t1(pctx, i);
const int n_tokens = parakeet_full_n_tokens(pctx, i);
fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text);
fprintf(stderr, "Tokens [%d]:\n", n_tokens);
for (int j = 0; j < n_tokens; j++) {
parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j);
const char * token_str = parakeet_token_to_str(pctx, token_data.id);
fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n",
j,
token_data.id,
token_data.frame_index,
token_data.duration_idx,
token_data.duration_value,
token_data.p,
token_data.plog,
(long long)token_data.t0,
(long long)token_data.t1,
token_data.is_word_start ? "true": "false",
token_str);
}
}
}
}
parakeet_free(pctx);
return 0;
}