#include "ggml.h"
#include "common.h"
#include "whisper.h"
#include "grammar-parser.h"

// NOTE(review): the angle-bracket header names (and all template arguments in
// this file) were lost before this review. The list below is reconstructed
// from the names the code actually uses - confirm against the original and
// re-add anything that is missing.
#include <emscripten.h>
#include <emscripten/bind.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

constexpr int N_THREAD = 8;

// One slot per whisper context; a null slot is free.
std::vector<struct whisper_context *> g_contexts(4, nullptr);

// g_mutex guards g_status, g_status_forced, g_transcribed, g_pcmf32 and g_board.
std::mutex  g_mutex;
std::thread g_worker;

std::atomic<bool> g_running(false);

std::string g_status        = "";
std::string g_status_forced = "";
std::string g_transcribed   = "";

std::vector<float> g_pcmf32;

// Publishes a new status string (read back through the "get_status" binding).
void command_set_status(const std::string & status) {
    std::lock_guard<std::mutex> lock(g_mutex);
    g_status = status;
}

// Runs whisper_full() over `pcmf32` and returns the concatenated segment text
// ("" when whisper_full() fails).
//
// Outputs (all reset to zero first):
//   logprob_min - smallest token.plog seen (never above 0)
//   logprob_sum - sum of token.plog over all tokens
//   n_tokens    - number of tokens visited
//   t_ms        - wall-clock duration in milliseconds (left at 0 on failure)
std::string command_transcribe(
        whisper_context * ctx,
        const whisper_full_params & wparams,
        const std::vector<float> & pcmf32,
        float & logprob_min,
        float & logprob_sum,
        int & n_tokens,
        int64_t & t_ms) {
    const auto t_start = std::chrono::high_resolution_clock::now();

    logprob_min = 0.0f;
    logprob_sum = 0.0f;
    n_tokens    = 0;
    t_ms        = 0;

    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        return "";
    }

    std::string result;

    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
        const char * text = whisper_full_get_segment_text(ctx, i);
        result += text;

        const int n = whisper_full_n_tokens(ctx, i);
        for (int j = 0; j < n; ++j) {
            const auto token = whisper_full_get_token_data(ctx, i, j);

            if (token.plog > 0.0f) exit(0); // todo: check for emscripten

            logprob_min = std::min(logprob_min, token.plog);
            logprob_sum += token.plog;
            ++n_tokens;
        }
    }

    const auto t_end = std::chrono::high_resolution_clock::now();
    t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();

    return result;
}

// Copies the most recent `ms` milliseconds of captured audio (or everything
// available, if less) from g_pcmf32 into `audio`.
//
// Takes g_mutex: the "set_audio" binding resizes and overwrites g_pcmf32 under
// that mutex from another thread, so reading it unguarded is a data race.
// Callers must therefore NOT hold g_mutex.
void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
    // widen before multiplying so large ms * sample_rate cannot overflow int
    const int64_t n_samples = (static_cast<int64_t>(ms) * sample_rate) / 1000;

    std::lock_guard<std::mutex> lock(g_mutex);

    const int64_t n_avail = static_cast<int64_t>(g_pcmf32.size());
    // clamp to [0, n_avail] so a non-positive request yields an empty buffer
    const int64_t n_take  = std::max<int64_t>(0, std::min(n_samples, n_avail));

    audio.assign(g_pcmf32.end() - n_take, g_pcmf32.end());
}

// Square names indexed by board position: index = row * 8 + column,
// row 0 is rank 1, column 0 is file 'a'.
static constexpr std::array<const char*, 64> positions = {
    "a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
    "a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2",
    "a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3",
    "a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4",
    "a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5",
    "a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6",
    "a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7",
    "a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8",
};

// Spoken piece names, in the same order as Board::Piece::Types.
static constexpr std::array<const char*, 6> pieceNames = {
    "pawn", "knight", "bishop", "rook", "queen", "king",
};

// Minimal chess board: tracks piece placement, validates spoken moves against
// basic movement rules and queues them until they are committed or discarded.
class Board {
public:
    struct Piece {
        enum Types {
            Pawn,
            Knight,
            Bishop,
            Rook,
            Queen,
            King,
            Taken,   // captured; no longer referenced by `board`
        };
        static_assert(pieceNames.size() == Piece::Taken, "Mismatch between piece names and types");

        enum Colors {
            Black,
            White
        };

        Types  type;
        Colors color;
        int    pos;   // index into `positions` / `board`
    };

    std::array<Piece, 16> blackPieces = {{
        {Piece::Pawn,   Piece::Black, 48 },
        {Piece::Pawn,   Piece::Black, 49 },
        {Piece::Pawn,   Piece::Black, 50 },
        {Piece::Pawn,   Piece::Black, 51 },
        {Piece::Pawn,   Piece::Black, 52 },
        {Piece::Pawn,   Piece::Black, 53 },
        {Piece::Pawn,   Piece::Black, 54 },
        {Piece::Pawn,   Piece::Black, 55 },
        {Piece::Rook,   Piece::Black, 56 },
        {Piece::Knight, Piece::Black, 57 },
        {Piece::Bishop, Piece::Black, 58 },
        {Piece::Queen,  Piece::Black, 59 },
        {Piece::King,   Piece::Black, 60 },
        {Piece::Bishop, Piece::Black, 61 },
        {Piece::Knight, Piece::Black, 62 },
        {Piece::Rook,   Piece::Black, 63 },
    }};

    std::array<Piece, 16> whitePieces = {{
        {Piece::Pawn,   Piece::White,  8 },
        {Piece::Pawn,   Piece::White,  9 },
        {Piece::Pawn,   Piece::White, 10 },
        {Piece::Pawn,   Piece::White, 11 },
        {Piece::Pawn,   Piece::White, 12 },
        {Piece::Pawn,   Piece::White, 13 },
        {Piece::Pawn,   Piece::White, 14 },
        {Piece::Pawn,   Piece::White, 15 },
        {Piece::Rook,   Piece::White,  0 },
        {Piece::Knight, Piece::White,  1 },
        {Piece::Bishop, Piece::White,  2 },
        {Piece::Queen,  Piece::White,  3 },
        {Piece::King,   Piece::White,  4 },
        {Piece::Bishop, Piece::White,  5 },
        {Piece::Knight, Piece::White,  6 },
        {Piece::Rook,   Piece::White,  7 },
    }};

    using BB = std::array<Piece*, 64>;

    // Occupancy map: board[pos] points at the piece standing on `pos`, or is null.
    BB board = {{
        &whitePieces[ 8], &whitePieces[ 9], &whitePieces[10], &whitePieces[11], &whitePieces[12], &whitePieces[13], &whitePieces[14], &whitePieces[15],
        &whitePieces[ 0], &whitePieces[ 1], &whitePieces[ 2], &whitePieces[ 3], &whitePieces[ 4], &whitePieces[ 5], &whitePieces[ 6], &whitePieces[ 7],
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        &blackPieces[ 0], &blackPieces[ 1], &blackPieces[ 2], &blackPieces[ 3], &blackPieces[ 4], &blackPieces[ 5], &blackPieces[ 6], &blackPieces[ 7],
        &blackPieces[ 8], &blackPieces[ 9], &blackPieces[10], &blackPieces[11], &blackPieces[12], &blackPieces[13], &blackPieces[14], &blackPieces[15],
    }};

    // Returns true when `piece` may go to square `pos` on the current board.
    //
    // With kingCheck == false this is a move query: the target must be empty
    // or hold an opposing piece, pawns push onto empty squares and capture
    // diagonally, and a king may not step onto a square an opposing piece
    // attacks.
    //
    // With kingCheck == true this is an attack query ("does `piece` cover
    // `pos`?"), used while vetting a king move: the occupant of `pos` is
    // ignored (a defended piece cannot be captured by the king), pawns cover
    // only their two forward diagonals, and the king branch does not recurse.
    //
    // Captured pieces and `pos == piece.pos` always yield false. `pos` must be
    // a valid board index (0..63).
    bool checkNext(const Piece& piece, int pos, bool kingCheck = false) {
        if (piece.type == Piece::Taken) return false;
        if (piece.pos == pos) return false;

        const int i  = piece.pos / 8;   // source row
        const int j  = piece.pos % 8;   // source column
        const int di = pos / 8 - i;     // row delta
        const int dj = pos % 8 - j;     // column delta

        const bool diagonal = std::abs(di) == std::abs(dj);
        const bool straight = di == 0 || dj == 0;

        switch (piece.type) {
            case Piece::Pawn: {
                // White moves towards higher rows, black towards lower ones.
                const int direction = piece.color == Piece::White ? 1 : -1;

                if (di == direction && std::abs(dj) == 1) {
                    if (kingCheck) return true;   // covered square, occupant irrelevant
                    return board[pos] != nullptr && board[pos]->color != piece.color;
                }

                // Straight pushes never attack anything.
                if (kingCheck || dj != 0) return false;

                if (di == direction) return board[pos] == nullptr;

                // The double step is only available from the pawn's home row
                // and needs both squares free.
                const int homeRow = piece.color == Piece::White ? 1 : 6;
                if (di == 2 * direction && i == homeRow) {
                    return board[(i + direction) * 8 + j] == nullptr && board[pos] == nullptr;
                }
                return false;
            }

            case Piece::Knight: {
                const int ai = std::abs(di);
                const int aj = std::abs(dj);
                if ((ai == 2 && aj == 1) || (ai == 1 && aj == 2)) return canLand(piece, pos, kingCheck);
                return false;
            }

            case Piece::Bishop:
                return diagonal && pathClear(piece.pos, pos) && canLand(piece, pos, kingCheck);

            case Piece::Rook:
                return straight && pathClear(piece.pos, pos) && canLand(piece, pos, kingCheck);

            case Piece::Queen:
                return (diagonal || straight) && pathClear(piece.pos, pos) && canLand(piece, pos, kingCheck);

            case Piece::King: {
                if (std::abs(di) >= 2 || std::abs(dj) >= 2) return false;

                if (!kingCheck) {
                    // The king may not step onto a square covered by the
                    // OTHER side. (Scanning the king's own side would include
                    // the king itself and reject every move.)
                    const auto& enemies = piece.color == Piece::White ? blackPieces : whitePieces;
                    for (const auto& enemyPiece : enemies) {
                        if (enemyPiece.type != Piece::Taken && checkNext(enemyPiece, pos, true)) return false;
                    }
                }
                return canLand(piece, pos, kingCheck);
            }

            case Piece::Taken:
                break;
        }

        return false;
    }

    // NOTE(review): nothing in this file ever advances moveCount, so addMoves()
    // always selects the white pieces - confirm whether that is intended.
    int moveCount = 0;

    // Parses a comma-separated list of spoken moves (e.g. "pawn to e4, knight
    // to c3") and queues every move that names a known piece type and square
    // and that some piece of that type may legally make. Unparseable or
    // illegal moves are skipped silently.
    void addMoves(const std::string& t) {
        std::vector<std::string> moves;
        size_t last = 0;
        while (true) {
            const size_t cur = t.find(',', last);
            if (cur == std::string::npos) {
                moves.push_back(t.substr(last));
                break;
            }
            // substr takes a COUNT, not an end index
            moves.push_back(t.substr(last, cur - last));
            last = cur + 1;
        }

        // Index of the first entry of `names` contained in `text`, or names.size().
        const auto firstMatch = [](const std::string& text, const auto& names) {
            size_t idx = 0;
            for (; idx < names.size(); ++idx) {
                if (text.find(names[idx]) != std::string::npos) break;
            }
            return idx;
        };

        // fixme: lookup depends on grammar
        const int count = moveCount;
        for (const auto& move : moves) {
            fprintf(stdout, "%s: Move '%s%s%s'\n", __func__, "\033[1m", move.c_str(), "\033[0m");
            if (move.empty()) continue;

            const size_t nameIndex = firstMatch(move, pieceNames);
            const size_t posIndex  = firstMatch(move, positions);
            if (nameIndex >= pieceNames.size() || posIndex >= positions.size()) continue;

            auto& pieces = count % 2 ? blackPieces : whitePieces;
            const auto type   = Piece::Types(nameIndex);
            const int  target = static_cast<int>(posIndex);

            // First piece of the requested type that can reach the square wins.
            for (auto& candidate : pieces) {
                if (candidate.type == type && checkNext(candidate, target)) {
                    m_pendingMoves.emplace_back(&candidate, target);
                    break;
                }
            }
        }
    }

    // Renders the pending moves as "from-to" pairs separated by single spaces,
    // e.g. "e2-e4 b1-c3". Returns "" when nothing is pending.
    std::string stringifyMoves() {
        std::string res;
        for (const auto& m : m_pendingMoves) {
            res.append(positions[m.first->pos]);
            res.push_back('-');
            res.append(positions[m.second]);
            res.push_back(' ');
        }
        if (!res.empty()) res.pop_back();   // drop the trailing separator
        return res;
    }

    // Applies every pending move in order (marking any piece on the target
    // square as Taken) and clears the queue.
    void commitMoves() {
        for (auto& m : m_pendingMoves) {
            if (board[m.second]) board[m.second]->type = Piece::Taken;
            board[m.first->pos] = nullptr;
            m.first->pos = m.second;
            board[m.second] = m.first;
        }
        m_pendingMoves.clear();
    }

    // Queued (piece, target square) pairs awaiting commitMoves().
    std::vector<std::pair<Piece*, int>> m_pendingMoves;

private:
    static int sign(int v) { return (v > 0) - (v < 0); }

    // True when `piece` may end on `pos`: always for an attack query,
    // otherwise only on an empty square or one held by the other colour.
    bool canLand(const Piece& piece, int pos, bool attackOnly) const {
        return attackOnly || board[pos] == nullptr || board[pos]->color != piece.color;
    }

    // True when every square strictly between `from` and `to` is empty.
    // The two squares must differ and share a row, column or diagonal.
    bool pathClear(int from, int to) const {
        const int ii    = to / 8;
        const int jj    = to % 8;
        const int stepI = sign(ii - from / 8);
        const int stepJ = sign(jj - from % 8);

        int i = from / 8 + stepI;
        int j = from % 8 + stepJ;
        while (i != ii || j != jj) {
            if (board[i * 8 + j]) return false;
            i += stepI;
            j += stepJ;
        }
        return true;
    }
};

Board g_board;

// Worker-thread body: waits for the activation phrase, then repeatedly
// transcribes voice commands and feeds them to g_board until g_running is
// cleared. Frees g_contexts[index] on exit.
void command_main(size_t index) {
    command_set_status("loading data ...");

    struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);

    wparams.n_threads        = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
    wparams.offset_ms        = 0;
    wparams.translate        = false;
    wparams.no_context       = true;
    wparams.single_segment   = true;
    wparams.print_realtime   = false;
    wparams.print_progress   = false;
    wparams.print_timestamps = true;
    wparams.print_special    = false;

    wparams.max_tokens       = 32;
    // wparams.audio_ctx        = 768; // partial encoder context for better performance

    wparams.temperature      = 0.4f;
    wparams.temperature_inc  = 1.0f;
    wparams.greedy.best_of   = 1;

    wparams.beam_search.beam_size = 5;

    wparams.language         = "en";

    printf("command: using %d threads\n", wparams.n_threads);

    bool have_prompt  = false;
    bool ask_prompt   = true;
    bool print_energy = false;

    float logprob_min = 0.0f;
    float logprob_sum = 0.0f;
    int   n_tokens    = 0;

    std::vector<float> pcmf32_cur;
    std::vector<float> pcmf32_prompt;

    // todo: grammar to be based on js input
    const std::string k_prompt = "rook to b4, f3,";
    wparams.initial_prompt = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";

    auto grammar_parsed = grammar_parser::parse(
        "\n"
        "root ::= init move move? move? \".\"\n"
        "prompt ::= init \".\"\n"
        "\n"
        "# leading space is very important!\n"
        "init ::= \" rook to b4, f3\"\n"
        "\n"
        "move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
        "\n"
        "piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
        "king ::= \"king\"\n"
        "pawn ::= \"pawn\"\n"
        "\n"
    );
    auto grammar_rules = grammar_parsed.c_rules();

    if (grammar_parsed.rules.empty()) {
        fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__);
    }
    else {
        wparams.grammar_rules   = grammar_rules.data();
        wparams.n_grammar_rules = grammar_rules.size();
        wparams.grammar_penalty = 100.0;
    }

    // whisper context
    auto & ctx = g_contexts[index];

    const int32_t vad_ms     = 2000;
    const int32_t prompt_ms  = 5000;
    const int32_t command_ms = 4000;

    const float vad_thold  = 0.1f;
    const float freq_thold = -1.0f;

    while (g_running) {
        // delay
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        if (ask_prompt) {
            fprintf(stdout, "\n");
            fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
            fprintf(stdout, "\n");

            {
                char status[1024];
                snprintf(status, sizeof(status), "Say the following phrase: '%s'", k_prompt.c_str());
                command_set_status(status);
            }

            ask_prompt = false;
        }

        int64_t t_ms = 0;

        {
            command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);

            if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
                fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
                command_set_status("Speech detected! Processing ...");

                if (!have_prompt) {
                    // Activation phase: the user must repeat k_prompt.
                    command_get_audio(prompt_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);

                    wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt");
                    const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));

                    fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);

                    const float sim = similarity(txt, k_prompt);

                    if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
                        fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
                        ask_prompt = true;
                    } else {
                        fprintf(stdout, "\n");
                        fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
                        fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
                        fprintf(stdout, "\n");

                        {
                            char status[1024];
                            snprintf(status, sizeof(status), "Success! Waiting for voice commands ...");
                            command_set_status(status);
                        }

                        // save the audio for the prompt
                        pcmf32_prompt = pcmf32_cur;
                        have_prompt = true;
                    }
                } else {
                    // Command phase: transcribe "<prompt audio><silence><command audio>"
                    // and strip the prompt text from the front of the result.
                    command_get_audio(command_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);

                    // prepend 3 second of silence
                    pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);

                    // prepend the prompt audio
                    pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());

                    wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
                    const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));

                    const float p = 100.0f * std::exp(logprob_min);

                    fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());

                    // find the prompt in the text
                    float best_sim = 0.0f;
                    size_t best_len = 0;
                    for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
                        if (n >= int(txt.size())) {
                            break;
                        }

                        const auto prompt = txt.substr(0, n);

                        const float sim = similarity(prompt, k_prompt);

                        //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);

                        if (sim > best_sim) {
                            best_sim = sim;
                            best_len = n;
                        }
                    }

                    fprintf(stdout, "%s:   DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);

                    std::string command = ::trim(txt.substr(best_len));

                    fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
                    fprintf(stdout, "\n");

                    {
                        char status[1024];
                        snprintf(status, sizeof(status), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
                        command_set_status(status);
                    }
                    {
                        std::lock_guard<std::mutex> lock(g_mutex);
                        if (!command.empty()) {
                            g_board.addMoves(command);
                        }
                        g_transcribed = std::move(command);
                    }
                }

                {
                    // g_pcmf32 is shared with the "set_audio" binding
                    std::lock_guard<std::mutex> lock(g_mutex);
                    g_pcmf32.clear();
                }
            }
        }
    }

    if (index < g_contexts.size()) {
        whisper_free(g_contexts[index]);
        g_contexts[index] = nullptr;
    }
}

EMSCRIPTEN_BINDINGS(command) {
    // init(path_model): loads a model into the first free context slot and
    // starts the worker thread. Returns slot index + 1, or 0 on failure.
    emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
        for (size_t i = 0; i < g_contexts.size(); ++i) {
            if (g_contexts[i] == nullptr) {
                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                if (g_contexts[i] != nullptr) {
                    // Stop and reap any previous worker BEFORE raising
                    // g_running again; otherwise that worker never leaves its
                    // loop and join() blocks forever.
                    if (g_worker.joinable()) {
                        g_running = false;
                        g_worker.join();
                    }
                    g_running = true;
                    g_worker = std::thread([i]() {
                        command_main(i);
                    });

                    return i + 1;
                } else {
                    return (size_t) 0;
                }
            }
        }

        return (size_t) 0;
    }));

    // free(index): asks the worker to stop; the worker releases its own context.
    emscripten::function("free", emscripten::optional_override([](size_t /*index*/) {
        g_running = false;
    }));

    // set_audio(index, audio): replaces the captured audio with the contents of
    // the given typed array. Returns 0, -1 (bad index) or -2 (slot not initialised).
    emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
        --index;   // handles are 1-based; 0 wraps around and fails the range check

        if (index >= g_contexts.size()) {
            return -1;
        }

        if (g_contexts[index] == nullptr) {
            return -2;
        }

        {
            std::lock_guard<std::mutex> lock(g_mutex);
            const int n = audio["length"].as<int>();

            emscripten::val heap = emscripten::val::module_property("HEAPU8");
            emscripten::val memory = heap["buffer"];

            g_pcmf32.resize(n);

            // View over g_pcmf32's storage inside the wasm heap, then bulk-copy into it.
            emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
            memoryView.call<void>("set", audio);
        }

        return 0;
    }));

    // get_transcribed(): returns and clears the last transcribed command.
    emscripten::function("get_transcribed", emscripten::optional_override([]() {
        std::string transcribed;

        {
            std::lock_guard<std::mutex> lock(g_mutex);
            transcribed = std::move(g_transcribed);
        }

        return transcribed;
    }));

    // get_moves(): returns the pending moves as text (see Board::stringifyMoves).
    emscripten::function("get_moves", emscripten::optional_override([]() {
        std::string moves;

        {
            std::lock_guard<std::mutex> lock(g_mutex);
            moves = g_board.stringifyMoves();

            fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m");
        }

        return moves;
    }));

    // commit_moves(): applies the pending moves to the board.
    emscripten::function("commit_moves", emscripten::optional_override([]() {
        {
            std::lock_guard<std::mutex> lock(g_mutex);
            g_board.commitMoves();
        }
    }));

    // discard_moves(): drops the pending moves without applying them.
    emscripten::function("discard_moves", emscripten::optional_override([]() {
        {
            std::lock_guard<std::mutex> lock(g_mutex);
            g_board.m_pendingMoves.clear();
        }
    }));

    // get_status(): the forced status if one is set, otherwise the worker's status.
    emscripten::function("get_status", emscripten::optional_override([]() {
        std::string status;

        {
            std::lock_guard<std::mutex> lock(g_mutex);
            status = g_status_forced.empty() ? g_status : g_status_forced;
        }

        return status;
    }));

    // set_status(status): overrides the reported status; "" removes the override.
    emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
        {
            std::lock_guard<std::mutex> lock(g_mutex);
            g_status_forced = status;
        }
    }));
}