From f08258abd74b995bb95d8005103f72f1afd66a8a Mon Sep 17 00:00:00 2001 From: annaeina <2846698728@qq.com> Date: Wed, 13 May 2026 13:32:00 +0800 Subject: [PATCH] whisper : fix max_tokens skipping remaining audio (#3798) * whisper: fix max_tokens skipping remaining audio * add PR reference comment as suggested Co-authored-by: Georgi Gerganov * fix(ci): enable artifact overwrite --- .github/workflows/build.yml | 1 + bindings/go/pkg/whisper/context_test.go | 48 +++++++++++++++++++++++++ src/whisper.cpp | 12 +++++++ 3 files changed, 61 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fb115b22a..be3f78a3f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -662,6 +662,7 @@ jobs: with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll + overwrite: true - name: Upload ggml base dll uses: actions/upload-artifact@v6 diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go index e98a4c2b8..79f6a5930 100644 --- a/bindings/go/pkg/whisper/context_test.go +++ b/bindings/go/pkg/whisper/context_test.go @@ -2,6 +2,7 @@ package whisper_test import ( "os" + "strings" "testing" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" @@ -92,6 +93,53 @@ func TestProcess(t *testing.T) { assert.NoError(err) } +func TestProcessMaxTokensPerSegment(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Decode the WAV file - load the full buffer + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + assert.NoError(err) + assert.Equal(uint16(1), dec.NumChans) + + data := buf.AsFloat32Buffer().Data + + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + defer model.Close() + + context, err := model.NewContext() + assert.NoError(err) + + context.SetMaxTokensPerSegment(5) + + err = context.Process(data, nil, nil, nil) + assert.NoError(err) + + var text strings.Builder + nSegments := 0 + for { + segment, err := context.NextSegment() + if err != nil { + break + } + nSegments++ + text.WriteString(segment.Text) + } + + assert.Greater(nSegments, 1) + assert.Contains(text.String(), "country") +} + func TestDetectedLanguage(t *testing.T) { assert := assert.New(t) diff --git a/src/whisper.cpp b/src/whisper.cpp index 6176d21f5..210ca597f 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6216,6 +6216,13 @@ static void whisper_process_logits( } } + // ref: https://github.com/ggml-org/whisper.cpp/pull/3798 + if (!params.no_timestamps && !params.single_segment && params.max_tokens > 0 && (int) tokens_cur.size() >= params.max_tokens) { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; logits[vocab.token_nosp] = -INFINITY; @@ -7725,7 +7732,12 @@ int whisper_full_with_state( } // ref: https://github.com/ggml-org/whisper.cpp/pull/2629 + const bool max_tokens_timestamp_ending = params.max_tokens > 0 && + !params.single_segment && + tokens_cur.size() > (size_t) params.max_tokens; + const bool single_timestamp_ending = tokens_cur.size() > 1 && + !max_tokens_timestamp_ending && tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) && tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx); if (single_timestamp_ending) {