From 54d168db67f63344ee3cd3bfc0137960d5f4d3fa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Sep 2023 17:58:54 +0300 Subject: [PATCH] command : grammar-related improvements - option to read grammar from file - add sample grammars for colors and chess moves - fine-tune the performance further --- examples/command/command.cpp | 77 ++++++++++++++++++++++++------------ examples/grammar-parser.cpp | 2 +- examples/grammar-parser.h | 2 +- grammars/chess.gbnf | 27 +++++++++++++ grammars/colors.gbnf | 24 +++++++++++ whisper.cpp | 65 ++++++++++++++++++------------ 6 files changed, 144 insertions(+), 53 deletions(-) create mode 100644 grammars/chess.gbnf create mode 100644 grammars/colors.gbnf diff --git a/examples/command/command.cpp b/examples/command/command.cpp index f33f8e15..dfffbda7 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -22,6 +22,11 @@ #include #include +bool file_exists(const std::string & fname) { + std::ifstream f(fname.c_str()); + return f.good(); +} + // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -36,6 +41,8 @@ struct whisper_params { float grammar_penalty = 100.0f; + grammar_parser::parse_state grammar_parsed; + bool speed_up = false; bool translate = false; bool print_special = false; @@ -117,15 +124,18 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "\n"); } -std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { +std::string transcribe( + whisper_context * ctx, + const whisper_params & params, + const std::vector & pcmf32, + const std::string & grammar_rule, + float & prob, + int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); prob = 0.0f; t_ms = 0; - grammar_parser::parse_state parsed_grammar; - std::vector grammar_rules; - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_progress = false; @@ -140,17 +150,20 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.n_threads = params.n_threads; // disable fallback - seems not useful for command recognition - wparams.temperature_inc = 0.0f; + wparams.temperature_inc = 0.00f; - wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - grammar_rules = parsed_grammar.c_rules(); + //wparams.initial_prompt = params.prompt.data(); + + const auto & grammar_parsed = params.grammar_parsed; + auto grammar_rules = grammar_parsed.c_rules(); + + if (!params.grammar_parsed.rules.empty()) { wparams.grammar_rules = grammar_rules.data(); wparams.n_grammar_rules = grammar_rules.size(); - wparams.i_start_rule = parsed_grammar.symbol_ids.at("root"); + wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule); wparams.grammar_penalty = params.grammar_penalty; } @@ -270,7 +283,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const fprintf(stderr, " ]\n"); } - std::string k_prompt = "select one from the available words: "; + std::string k_prompt = "select one from the available words: "; for (int i = 0; i < (int) allowed_commands.size(); ++i) { if (i > 0) { k_prompt += ", "; @@ -476,7 +489,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // detect the commands audio.get(params.command_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", prob, t_ms)); const auto words = get_words(txt); @@ -523,9 +536,10 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au std::vector pcmf32_cur; std::vector pcmf32_prompt; - //const std::string k_prompt = "Ok Whisper, start listening for commands."; - //const std::string k_prompt = "Начало."; - const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди."; + std::string k_prompt = "Ok Whisper, start listening for commands."; + if (!params.prompt.empty()) { + k_prompt = params.prompt; + } fprintf(stderr, "\n"); fprintf(stderr, "%s: general-purpose mode\n", __func__); @@ -558,7 +572,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au // wait for activation phrase audio.get(params.prompt_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob0, t_ms)); fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); @@ -581,13 +595,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au // we have heard the activation phrase, now detect the commands audio.get(params.command_ms, pcmf32_cur); + //printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE); + //printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE); + + // prepend 3 second of silence + pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f); + // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - // append 1 second of silence - pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f); - - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob, t_ms)); prob = 100.0f*(prob - prob0); @@ -688,13 +705,23 @@ int main(int argc, char ** argv) { int ret_val = 0; if (!params.grammar.empty()) { - auto parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + auto & grammar = params.grammar_parsed; + if (file_exists(params.grammar.c_str())) { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } else { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { + if (grammar.rules.empty()) { ret_val = 1; } else { fprintf(stderr, "%s: grammar:\n", __func__); - grammar_parser::print_grammar(stderr, parsed_grammar); + grammar_parser::print_grammar(stderr, grammar); fprintf(stderr, "\n"); } } @@ -702,7 +729,7 @@ int main(int argc, char ** argv) { if (ret_val == 0) { if (!params.commands.empty()) { ret_val = process_command_list(ctx, audio, params); - } else if (!params.prompt.empty()) { + } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) { ret_val = always_prompt_transcription(ctx, audio, params); } else { ret_val = process_general_transcription(ctx, audio, params); diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index b5b607fa..2daaaef4 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -413,7 +413,7 @@ namespace grammar_parser { } } - std::vector parse_state::c_rules() { + std::vector parse_state::c_rules() const{ std::vector ret; for (const auto & rule : rules) { ret.push_back(rule.data()); diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h index ef0ec441..47d019c3 100644 --- a/examples/grammar-parser.h +++ b/examples/grammar-parser.h @@ -21,7 +21,7 @@ namespace grammar_parser { std::map symbol_ids; std::vector> rules; - std::vector c_rules(); + std::vector c_rules() const; }; parse_state parse(const char * src); diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 00000000..122ce123 --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,27 @@ +# - bishop to c3 +# - rook to d4 +# - knight to e5 +# - d4 d5 knight to c3 +# - c3 queen to d4 king b1 +# - pawn to a1 bishop to b2 knight to c3 +# +# initial prompt: +# +# "pawn to a1, bishop to b2, knight to c3, rook to d4, queen to e5, king to f6," +# +# example: +# +# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6" +# + +root ::= init (move? move? move? ".") +prompt ::= init "." + +# leading space is very important! +init ::= " pawn knight king a1 f5 h6" + +move ::= " " ((piece | pawn | king) " " "to "?)? [a-h] [1-8] + +piece ::= "bishop" | "rook" | "knight" | "queen" +king ::= "king" +pawn ::= "pawn" diff --git a/grammars/colors.gbnf b/grammars/colors.gbnf new file mode 100644 index 00000000..f4a4930c --- /dev/null +++ b/grammars/colors.gbnf @@ -0,0 +1,24 @@ +# - red +# - green +# - blue +# - red green +# - red blue +# - green red +# - green blue green +# +# initial prompt: +# +# "red green blue" +# +# example: +# +# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue" +# + +root ::= init color (color)? (color)? "." +prompt ::= init "." + +# leading space is very important! +init ::= " red green blue" + +color ::= " " ("red" | "green" | "blue") diff --git a/whisper.cpp b/whisper.cpp index 5e3b86a8..c357994b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3865,28 +3865,29 @@ static struct whisper_grammar whisper_grammar_init( static void whisper_suppress_invalid_grammar( whisper_context & ctx, const whisper_full_params & params, - std::vector & logprobs, + std::vector & logits, const whisper_grammar & grammar) { if (grammar.rules.empty() || grammar.stacks.empty()) { return; } - // bool allow_eot = false; - // for (const auto & stack : grammar.stacks) { - // if (stack.empty()) { - // allow_eot = true; - // break; - // } - // } + bool allow_eot = false; + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + allow_eot = true; + break; + } + } + + const whisper_token eot = whisper_token_eot(&ctx); std::vector, whisper_partial_utf8>> candidates_decoded; std::vector candidates_grammar; - size_t size = logprobs.size(); - for (whisper_token id = 0; id < (int) size; ++id) { + for (whisper_token id = 0; id < eot; ++id) { const std::string & text = ctx.vocab.id_to_token[id]; - if (!text.empty() && text.rfind("[_", 0) != 0) { + if (!text.empty()) { candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } @@ -3895,14 +3896,12 @@ static void whisper_suppress_invalid_grammar( const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { - logprobs[reject.id] -= params.grammar_penalty; + logits[reject.id] -= params.grammar_penalty; } - // when the grammar does not allow any continuation, we don't want to penalize the EOT token - // TODO: is there are better way to do this? - printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2); - if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) { - logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty; + // when the grammar allows a continuation, we penalize the end-of-text token + if (!allow_eot) { + logits[eot] -= params.grammar_penalty; } //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } @@ -3912,7 +3911,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar return; } - fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); + //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); const std::string & text = ctx.vocab.id_to_token[token]; @@ -4308,14 +4307,28 @@ static void whisper_process_logits( logits[i] = -INFINITY; logprobs[i] = -INFINITY; } - } else { - //printf("sampling text\n"); - for (int i = vocab.token_beg; i < n_logits; ++i) { - logits[i] = -INFINITY; - logprobs[i] = -INFINITY; - } + } else if (params.n_grammar_rules > 0) { + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); - whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar); + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } } } } @@ -4331,7 +4344,7 @@ static void whisper_process_logits( } } -#if 1 +#if 0 // print first 100 logits - token string : logit //for (int i = 0; i < 10; i++) { // const auto token = vocab.id_to_token.at(i);