whisper : fix max_tokens skipping remaining audio (#3798)

* whisper: fix max_tokens skipping remaining audio

* add PR reference comment as suggested

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix(ci): enable artifact overwrite
This commit is contained in:
annaeina 2026-05-13 13:32:00 +08:00 committed by GitHub
parent 338cce1e58
commit f08258abd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 0 deletions

View File

@ -662,6 +662,7 @@ jobs:
with:
name: ggml_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml.dll
overwrite: true
- name: Upload ggml base dll
uses: actions/upload-artifact@v6

View File

@ -2,6 +2,7 @@ package whisper_test
import (
"os"
"strings"
"testing"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
@ -92,6 +93,53 @@ func TestProcess(t *testing.T) {
assert.NoError(err)
}
func TestProcessMaxTokensPerSegment(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
buf, err := dec.FullPCMBuffer()
assert.NoError(err)
assert.Equal(uint16(1), dec.NumChans)
data := buf.AsFloat32Buffer().Data
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
context, err := model.NewContext()
assert.NoError(err)
context.SetMaxTokensPerSegment(5)
err = context.Process(data, nil, nil, nil)
assert.NoError(err)
var text strings.Builder
nSegments := 0
for {
segment, err := context.NextSegment()
if err != nil {
break
}
nSegments++
text.WriteString(segment.Text)
}
assert.Greater(nSegments, 1)
assert.Contains(text.String(), "country")
}
func TestDetectedLanguage(t *testing.T) {
assert := assert.New(t)

View File

@ -6216,6 +6216,13 @@ static void whisper_process_logits(
}
}
// ref: https://github.com/ggml-org/whisper.cpp/pull/3798
if (!params.no_timestamps && !params.single_segment && params.max_tokens > 0 && (int) tokens_cur.size() >= params.max_tokens) {
for (int i = 0; i < vocab.token_eot; ++i) {
logits[i] = -INFINITY;
}
}
// suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_nosp] = -INFINITY;
@ -7725,7 +7732,12 @@ int whisper_full_with_state(
}
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
const bool max_tokens_timestamp_ending = params.max_tokens > 0 &&
!params.single_segment &&
tokens_cur.size() > (size_t) params.max_tokens;
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
!max_tokens_timestamp_ending &&
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
if (single_timestamp_ending) {