whisper : fix max_tokens skipping remaining audio (#3798)
* whisper: fix max_tokens skipping remaining audio * add PR reference comment as suggested Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix(ci): enable artifact overwrite
This commit is contained in:
parent
338cce1e58
commit
f08258abd7
|
|
@ -662,6 +662,7 @@ jobs:
|
|||
with:
|
||||
name: ggml_${{ matrix.arch }}.dll
|
||||
path: build/bin/${{ matrix.build }}/ggml.dll
|
||||
overwrite: true
|
||||
|
||||
- name: Upload ggml base dll
|
||||
uses: actions/upload-artifact@v6
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package whisper_test
|
|||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
|
|
@ -92,6 +93,53 @@ func TestProcess(t *testing.T) {
|
|||
assert.NoError(err)
|
||||
}
|
||||
|
||||
func TestProcessMaxTokensPerSegment(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
assert.Equal(uint16(1), dec.NumChans)
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
|
||||
context.SetMaxTokensPerSegment(5)
|
||||
|
||||
err = context.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
var text strings.Builder
|
||||
nSegments := 0
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
nSegments++
|
||||
text.WriteString(segment.Text)
|
||||
}
|
||||
|
||||
assert.Greater(nSegments, 1)
|
||||
assert.Contains(text.String(), "country")
|
||||
}
|
||||
|
||||
func TestDetectedLanguage(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
|
|
|
|||
|
|
@ -6216,6 +6216,13 @@ static void whisper_process_logits(
|
|||
}
|
||||
}
|
||||
|
||||
// ref: https://github.com/ggml-org/whisper.cpp/pull/3798
|
||||
if (!params.no_timestamps && !params.single_segment && params.max_tokens > 0 && (int) tokens_cur.size() >= params.max_tokens) {
|
||||
for (int i = 0; i < vocab.token_eot; ++i) {
|
||||
logits[i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// suppress sot and nosp tokens
|
||||
logits[vocab.token_sot] = -INFINITY;
|
||||
logits[vocab.token_nosp] = -INFINITY;
|
||||
|
|
@ -7725,7 +7732,12 @@ int whisper_full_with_state(
|
|||
}
|
||||
|
||||
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
|
||||
const bool max_tokens_timestamp_ending = params.max_tokens > 0 &&
|
||||
!params.single_segment &&
|
||||
tokens_cur.size() > (size_t) params.max_tokens;
|
||||
|
||||
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
|
||||
!max_tokens_timestamp_ending &&
|
||||
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
|
||||
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
|
||||
if (single_timestamp_ending) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue