This commit is contained in:
Kirill 2026-04-23 05:22:38 +00:00 committed by GitHub
commit d7698ab1fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 3305 additions and 601 deletions

1
.gitignore vendored
View File

@ -3,6 +3,7 @@
*.d
.cache/
.coreml/
pkg/
.test/
.venv/
.vs/

View File

@ -1,2 +1,4 @@
build
models
samples/a13.wav
samples/benchmark_out.wav

View File

@ -35,7 +35,7 @@ whisper: mkdir
-DBUILD_SHARED_LIBS=OFF
cmake --build ../../${BUILD_DIR} --target whisper
test: model-small whisper modtidy
test: model-tiny whisper modtidy
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
@ -46,8 +46,15 @@ endif
examples: $(EXAMPLES_DIR)
model-small: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
benchmark: model-tiny whisper modtidy
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -bench='BenchmarkContextProcessCPU$$' -benchtime=1x -benchmem -run '^$$' ./pkg/whisper/...
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -bench='BenchmarkContextProcessCPU$$' -benchtime=1x -benchmem -run '^$$' ./pkg/whisper/...
endif
model-tiny: mkdir examples/go-model-download
@${BUILD_DIR}/go-model-download -out models ggml-tiny.en.bin
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)

View File

@ -7,8 +7,12 @@ This package provides Go bindings for whisper.cpp. They have been tested on:
* Fedora Linux on x86_64
The "low level" bindings are in the `bindings/go` directory and there is a more
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
is as follows:
Go-style package in the `bindings/go/pkg/whisper` directory.
Legacy stateless example (single worker). For the recommended stateful API and
concurrency-safe usage, see "New high-level API" below. Note: `Model.NewContext()`
returns a stateless context for backward compatibility and is not safe for parallel
`Process` calls (may return `ErrStatelessBusy`).
```go
import (
@ -100,6 +104,123 @@ Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312)
## New high-level API (stateful and stateless contexts)
The `pkg/whisper` package now exposes two context kinds:
- StatefulContext: recommended for concurrency. Each context owns its own whisper_state.
- StatelessContext: shares the model context. Simpler, but not suitable for parallel `Process` calls.
### Quick start: stateful context (recommended)
```go
package main
import (
"fmt"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func main() {
// Load model
model, err := whisper.NewModelContext("./models/ggml-small.en.bin")
if err != nil {
panic(err)
}
defer model.Close()
// Configure parameters (optional: provide a config func)
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {
p.SetThreads(4)
p.SetLanguage("en") // or "auto"
p.SetTranslate(false)
})
if err != nil {
panic(err)
}
// Create stateful context (safe for running in parallel goroutines)
ctx, err := whisper.NewStatefulContext(model, params)
if err != nil {
panic(err)
}
defer ctx.Close()
    // Your mono audio at 16kHz, as float32 PCM samples (e.g. decoded from a 16-bit WAV)
var samples []float32
// Process. Callbacks are optional.
if err := ctx.Process(samples, nil, nil, nil); err != nil {
panic(err)
}
// Read segments
for {
seg, err := ctx.NextSegment()
if err != nil {
break
}
fmt.Printf("[%v -> %v] %s\n", seg.Start, seg.End, seg.Text)
}
}
```
### Quick start: stateless context (single worker)
```go
// Load model as above
model, _ := whisper.NewModelContext("./models/ggml-small.en.bin")
defer model.Close()
params, _ := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
ctx, _ := whisper.NewStatelessContext(model, params)
defer ctx.Close()
if err := ctx.Process(samples, nil, nil, nil); err != nil { panic(err) }
for {
seg, err := ctx.NextSegment()
if err != nil { break }
fmt.Println(seg.Text)
}
```
### Deprecations and migration notes
- The `Context` interface setters are deprecated (SetThreads, SetLanguage, etc.). Use `Parameters` via `NewParameters` and pass it when creating a context.
- `Model.NewContext()` remains for backward compatibility and returns a stateless context by default. Prefer `NewStatefulContext` for concurrency.
- Stateless contexts share the model context. A concurrency gate prevents overlapping `Process` calls and will return `ErrStatelessBusy` if another `Process` is in flight.
- For parallel processing, create one `StatefulContext` per goroutine.
## Benchmarks
Benchmarks live in `pkg/whisper` and compare CPU vs GPU, stateful vs stateless, threads, and callback modes.
### Prerequisites
- Model: `models/ggml-small.en.bin` (or your choice).
- Sample: `samples/jfk.wav`.
- Build the C libs once (also downloads a model for examples):
```bash
cd bindings/go
make examples
# optionally: ./build/go-model-download -out models
```
### Run benchmarks
```bash
cd bindings/go/pkg/whisper
make benchmark
```
### What the benchmarks measure
- Variants: device (cpu/gpu) x context kind (stateless/stateful) x threads {1,2,4, NumCPU} x callback mode (NoCallback, WithSegmentCallback).
- Standard Go benchmark outputs: ns/op, B/op, allocs/op. Bytes-per-op is set to the byte size of the input samples, so throughput (MB/s) is reported as well.
- Custom metric `ms_process`: wall time per `Process` iteration, reported via `b.ReportMetric`.
- When `printTimings` is enabled, model-level timings are printed for NoCallback runs using `model.PrintTimings()`.
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -18,9 +18,10 @@ import (
// CONSTANTS
const (
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
srcUrlTinydiarize = "https://huggingface.co/akashmjn/tinydiarize-whisper.cpp/resolve/main/"
srcExt = ".bin" // Filename extension
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
)
var (
@ -38,6 +39,7 @@ var (
"large-v2", "large-v2-q5_0", "large-v2-q8_0",
"large-v3", "large-v3-q5_0",
"large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0",
"small.en-tdrz",
}
)
@ -219,6 +221,12 @@ func URLForModel(model string) (string, error) {
model += srcExt
}
srcUrl := srcUrl
if strings.Contains(model, "tdrz") {
srcUrl = srcUrlTinydiarize
}
// Parse the base URL
url, err := url.Parse(srcUrl)
if err != nil {

View File

@ -3,13 +3,13 @@ module github.com/ggerganov/whisper.cpp/bindings/go
go 1.23
require (
github.com/go-audio/audio v1.0.0
github.com/go-audio/wav v1.1.0
github.com/stretchr/testify v1.9.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-audio/audio v1.0.0 // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

View File

@ -47,6 +47,16 @@ func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}
// Enable extra debug information
func (p *Params) SetDebugMode(v bool) {
p.debug_mode = toBool(v)
}
// Enable tinydiarize speaker turn detection
func (p *Params) SetDiarize(v bool) {
p.tdrz_enable = toBool(v)
}
// Voice Activity Detection (VAD)
func (p *Params) SetVAD(v bool) {
p.vad = toBool(v)
@ -179,6 +189,8 @@ func (p *Params) SetInitialPrompt(prompt string) {
p.initial_prompt = C.CString(prompt)
}
// SetCarryInitialPrompt if true, always prepend initial_prompt to every decode window
// (may reduce conditioning on previous text)
func (p *Params) SetCarryInitialPrompt(v bool) {
p.carry_initial_prompt = toBool(v)
}
@ -236,9 +248,6 @@ func (p *Params) String() string {
if p.token_timestamps {
str += " token_timestamps"
}
if p.carry_initial_prompt {
str += " carry_initial_prompt"
}
return str + ">"
}

View File

@ -0,0 +1,58 @@
package whisper
import (
"sync"
"sync/atomic"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
// Gate provides a simple acquire/release contract per key.
// The default implementation is a single-entry lock per key (limit=1).
type Gate interface {
	// Acquire returns true if the key was acquired; false if already held
	// by another caller. A successful Acquire must be paired with Release.
	Acquire(key any) bool
	// Release releases the key if currently held. Releasing an unheld key
	// is a no-op.
	Release(key any)
}
// singleFlightGate is a minimal non-blocking lock with limit=1 per key.
// Map entries are created lazily on first Acquire and are never removed;
// growth is bounded by the number of distinct keys used (one per model
// context in practice).
type singleFlightGate struct {
	m sync.Map // key -> *atomic.Bool (false = available, true = held)
}

// Acquire attempts to take the single slot for key.
// It returns true when the slot was free and is now held by the caller,
// false when another caller currently holds it. It never blocks.
func (g *singleFlightGate) Acquire(key any) bool {
	slot, _ := g.m.LoadOrStore(key, new(atomic.Bool))
	return slot.(*atomic.Bool).CompareAndSwap(false, true)
}

// Release frees the slot for key. Releasing a key that was never
// acquired, or is not currently held, is a no-op.
func (g *singleFlightGate) Release(key any) {
	if slot, ok := g.m.Load(key); ok {
		slot.(*atomic.Bool).Store(false)
	}
}
// defaultGate is the process-wide gate used to serialize stateless
// Process calls per model. It is guarded by defaultGateMu because
// SetGate may be called concurrently with gate().
var (
	defaultGateMu sync.RWMutex
	defaultGate   Gate = &singleFlightGate{}
)

// SetGate allows applications to override the default gate (e.g., for custom policies).
// Passing nil resets to the default singleFlightGate. Safe for concurrent use.
func SetGate(g Gate) {
	defaultGateMu.Lock()
	defer defaultGateMu.Unlock()
	if g == nil {
		defaultGate = &singleFlightGate{}
		return
	}
	defaultGate = g
}

// gate returns the currently installed Gate.
func gate() Gate {
	defaultGateMu.RLock()
	defer defaultGateMu.RUnlock()
	return defaultGate
}
// modelKey derives a stable key per underlying model context, used to
// guard stateless operations. A nil model (or a model without a context
// accessor) yields a nil key.
func modelKey(model *ModelContext) *whisper.Context {
	if model == nil {
		return nil
	}
	accessor := model.ctxAccessor()
	if accessor == nil {
		return nil
	}
	key, _ := accessor.context()
	return key
}

View File

@ -11,11 +11,21 @@ import (
// ERRORS
var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrUnableToLoadModel = errors.New("unable to load model")
// Deprecated: Use ErrModelClosed instead for checking the model is closed error
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
ErrModelClosed = errors.Join(errors.New("model has been closed"), ErrInternalAppError)
ErrStatelessBusy = errors.New("stateless context is busy; concurrent processing not supported")
// Private errors
errParametersRequired = errors.New("parameters are required")
errModelRequired = errors.New("model is required")
errUnableToCreateState = errors.New("unable to create state")
)
///////////////////////////////////////////////////////////////////////////////
@ -26,3 +36,10 @@ const SampleRate = whisper.SampleRate
// SampleBits is the number of bytes per sample.
const SampleBits = whisper.SampleBits
type SamplingStrategy uint32
const (
SAMPLING_GREEDY SamplingStrategy = SamplingStrategy(whisper.SAMPLING_GREEDY)
SAMPLING_BEAM_SEARCH SamplingStrategy = SamplingStrategy(whisper.SAMPLING_BEAM_SEARCH)
)

View File

@ -1,385 +0,0 @@
package whisper
import (
"fmt"
"io"
"runtime"
"strings"
"time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
// context is the legacy stateless implementation of the Context
// interface; it borrows the model's whisper context rather than owning
// its own state.
type context struct {
	n      int            // cursor into decoded segments, advanced by NextSegment
	model  *model         // owning model; provides the underlying whisper context
	params whisper.Params // decoding parameters applied on Process
}

// Make sure context adheres to the interface
var _ Context = (*context)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
// newContext wraps a model and a parameter set in a stateless Context.
// It never fails; the error return exists for interface symmetry.
func newContext(model *model, params whisper.Params) (Context, error) {
	return &context{
		model:  model,
		params: params,
	}, nil
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// SetLanguage selects the language used for speech recognition.
// Pass "auto" to enable language auto-detection; any other value must be
// a language the (multilingual) model understands.
func (context *context) SetLanguage(lang string) error {
	if context.model.ctx == nil {
		return ErrInternalAppError
	}
	if !context.model.IsMultilingual() {
		return ErrModelNotMultilingual
	}
	if lang == "auto" {
		// -1 selects auto-detection; its return value is intentionally
		// ignored, matching the historical behavior of this method.
		context.params.SetLanguage(-1)
		return nil
	}
	id := context.model.ctx.Whisper_lang_id(lang)
	if id < 0 {
		return ErrUnsupportedLanguage
	}
	if err := context.params.SetLanguage(id); err != nil {
		return err
	}
	return nil
}
func (context *context) IsMultilingual() bool {
return context.model.IsMultilingual()
}
// Language reports the currently configured language, or "auto" when
// language auto-detection is enabled (internal id -1).
func (context *context) Language() string {
	if id := context.params.Language(); id != -1 {
		return whisper.Whisper_lang_str(id)
	}
	return "auto"
}
// DetectedLanguage returns the language identified by the most recent
// Process call. It returns an empty string when the model has been
// closed — mirroring the nil-ctx guards used by Process and NextSegment
// instead of panicking on a nil context.
func (context *context) DetectedLanguage() string {
	if context.model.ctx == nil {
		return ""
	}
	return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id())
}
// Set translate flag
func (context *context) SetTranslate(v bool) {
	context.params.SetTranslate(v)
}

// Voice Activity Detection (VAD): enable or disable VAD pre-filtering
func (context *context) SetVAD(v bool) {
	context.params.SetVAD(v)
}

// Set the filesystem path of the VAD model to use
func (context *context) SetVADModelPath(path string) {
	context.params.SetVADModelPath(path)
}

// Set the VAD speech-probability threshold
func (context *context) SetVADThreshold(t float32) {
	context.params.SetVADThreshold(t)
}

// Set the minimum speech duration, in milliseconds, for VAD
func (context *context) SetVADMinSpeechMs(ms int) {
	context.params.SetVADMinSpeechMs(ms)
}

// Set the minimum silence duration, in milliseconds, for VAD
func (context *context) SetVADMinSilenceMs(ms int) {
	context.params.SetVADMinSilenceMs(ms)
}

// Set the maximum speech segment duration, in seconds, for VAD
func (context *context) SetVADMaxSpeechSec(s float32) {
	context.params.SetVADMaxSpeechSec(s)
}

// Set the speech padding, in milliseconds, applied around VAD segments
func (context *context) SetVADSpeechPadMs(ms int) {
	context.params.SetVADSpeechPadMs(ms)
}

// Set the overlap, in seconds, between consecutive VAD sample windows
func (context *context) SetVADSamplesOverlap(sec float32) {
	context.params.SetVADSamplesOverlap(sec)
}

// Set whether segments are split on word boundaries
func (context *context) SetSplitOnWord(v bool) {
	context.params.SetSplitOnWord(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
	context.params.SetThreads(int(v))
}

// Set time offset (converted to whole milliseconds)
func (context *context) SetOffset(v time.Duration) {
	context.params.SetOffset(int(v.Milliseconds()))
}

// Set duration of audio to process (converted to whole milliseconds)
func (context *context) SetDuration(v time.Duration) {
	context.params.SetDuration(int(v.Milliseconds()))
}

// Set timestamp token probability threshold (~0.01)
func (context *context) SetTokenThreshold(t float32) {
	context.params.SetTokenThreshold(t)
}

// Set timestamp token sum probability threshold (~0.01)
func (context *context) SetTokenSumThreshold(t float32) {
	context.params.SetTokenSumThreshold(t)
}

// Set max segment length in characters
func (context *context) SetMaxSegmentLength(n uint) {
	context.params.SetMaxSegmentLength(int(n))
}

// Set token timestamps flag
func (context *context) SetTokenTimestamps(b bool) {
	context.params.SetTokenTimestamps(b)
}

// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
	context.params.SetMaxTokensPerSegment(int(n))
}

// Set audio encoder context
func (context *context) SetAudioCtx(n uint) {
	context.params.SetAudioCtx(int(n))
}

// Set maximum number of text context tokens to store
func (context *context) SetMaxContext(n int) {
	context.params.SetMaxContext(n)
}

// Set Beam Size
func (context *context) SetBeamSize(n int) {
	context.params.SetBeamSize(n)
}

// Set Entropy threshold
func (context *context) SetEntropyThold(t float32) {
	context.params.SetEntropyThold(t)
}

// Set Temperature
func (context *context) SetTemperature(t float32) {
	context.params.SetTemperature(t)
}

// Set the fallback temperature incrementation
// Pass -1.0 to disable this feature
func (context *context) SetTemperatureFallback(t float32) {
	context.params.SetTemperatureFallback(t)
}

// Set initial prompt
func (context *context) SetInitialPrompt(prompt string) {
	context.params.SetInitialPrompt(prompt)
}
// ResetTimings resets the mode timings. Should be called before processing.
// NOTE(review): unlike Process/NextSegment, there is no nil-ctx guard here;
// calling this after the model is closed would dereference a nil context —
// confirm callers never do so.
func (context *context) ResetTimings() {
	context.model.ctx.Whisper_reset_timings()
}

// PrintTimings prints the model timings to stdout.
// NOTE(review): no nil-ctx guard; see ResetTimings.
func (context *context) PrintTimings() {
	context.model.ctx.Whisper_print_timings()
}

// SystemInfo returns a one-line summary of the configured thread count,
// the host CPU count, and the whisper build's system capabilities.
func (context *context) SystemInfo() string {
	return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
		context.params.Threads(),
		runtime.NumCPU(),
		whisper.Whisper_print_system_info(),
	)
}
// WhisperLangAutoDetect uses mel data at offset_ms to try and auto-detect
// the spoken language. Make sure to call whisper_pcm_to_mel() or
// whisper_set_mel() first. Returns the probabilities of all languages.
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
	// The binding already returns (probs, err); pass the result through.
	return context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
}
// Process new sample data and return any errors.
// data is mono float32 PCM. The optional callbacks are invoked during
// decoding: callEncoderBegin before encoding, callNewSegment for each
// newly decoded segment, callProgress with a progress percentage.
func (context *context) Process(
	data []float32,
	callEncoderBegin EncoderBeginCallback,
	callNewSegment SegmentCallback,
	callProgress ProgressCallback,
) error {
	if context.model.ctx == nil {
		return ErrInternalAppError
	}
	// If the callback is defined then we force on single_segment mode
	if callNewSegment != nil {
		context.params.SetSingleSegment(true)
	}
	// We don't do parallel processing at the moment.
	// NOTE(review): processors is hard-coded to 0, so the
	// Whisper_full_parallel branch below is currently unreachable; it is
	// kept as scaffolding for future parallel support.
	processors := 0
	if processors > 1 {
		if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
			func(new int) {
				// Deliver only the segments appended by this callback invocation:
				// indices [total-new, total).
				if callNewSegment != nil {
					num_segments := context.model.ctx.Whisper_full_n_segments()
					s0 := num_segments - new
					for i := s0; i < num_segments; i++ {
						callNewSegment(toSegment(context.model.ctx, i))
					}
				}
			}); err != nil {
			return err
		}
	} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
		func(new int) {
			// Same windowing as above: only newly appended segments are reported.
			if callNewSegment != nil {
				num_segments := context.model.ctx.Whisper_full_n_segments()
				s0 := num_segments - new
				for i := s0; i < num_segments; i++ {
					callNewSegment(toSegment(context.model.ctx, i))
				}
			}
		}, func(progress int) {
			if callProgress != nil {
				callProgress(progress)
			}
		}); err != nil {
		return err
	}
	// Reset n so that more Segments can be available within NextSegment call
	context.n = 0
	// Return success
	return nil
}
// NextSegment returns the next decoded segment, advancing the internal
// cursor. It returns io.EOF once every segment from the last Process
// call has been consumed, and ErrInternalAppError if the model is closed.
func (context *context) NextSegment() (Segment, error) {
	if context.model.ctx == nil {
		return Segment{}, ErrInternalAppError
	}
	total := context.model.ctx.Whisper_full_n_segments()
	if context.n >= total {
		return Segment{}, io.EOF
	}
	segment := toSegment(context.model.ctx, context.n)
	context.n++
	return segment, nil
}
// IsText reports whether t is an ordinary text token, i.e. none of the
// special control tokens (BEG, SOT, PREV, SOLM, NOT) and below the
// end-of-transcription token id.
func (context *context) IsText(t Token) bool {
	if context.IsBEG(t) || context.IsSOT(t) || context.IsPREV(t) ||
		context.IsSOLM(t) || context.IsNOT(t) {
		return false
	}
	return whisper.Token(t.Id) < context.model.ctx.Whisper_token_eot()
}
// Test for "begin" token
func (context *context) IsBEG(t Token) bool {
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
}

// Test for "start of transcription" token
func (context *context) IsSOT(t Token) bool {
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
}

// Test for "end of transcription" token
func (context *context) IsEOT(t Token) bool {
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
}

// Test for "start of prev" token
func (context *context) IsPREV(t Token) bool {
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
}

// Test for "start of lm" token
func (context *context) IsSOLM(t Token) bool {
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
}

// Test for "No timestamps" token
func (context *context) IsNOT(t Token) bool {
	return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
}

// Test for token associated with a specific language.
// Returns false when lang is unknown to the model.
func (context *context) IsLANG(t Token, lang string) bool {
	if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
		return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
	} else {
		return false
	}
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
// toSegment converts segment n of the whisper context into a Segment,
// trimming surrounding whitespace from the text. Whisper timestamps are
// expressed in 10ms ticks, hence the conversion factor below.
func toSegment(ctx *whisper.Context, n int) Segment {
	const tick = 10 * time.Millisecond
	return Segment{
		Num:    n,
		Text:   strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
		Start:  time.Duration(ctx.Whisper_full_get_segment_t0(n)) * tick,
		End:    time.Duration(ctx.Whisper_full_get_segment_t1(n)) * tick,
		Tokens: toTokens(ctx, n),
	}
}
// toTokens extracts all tokens of segment n, including per-token
// probabilities and 10ms-tick timestamps converted to time.Duration.
func toTokens(ctx *whisper.Context, n int) []Token {
	const tick = 10 * time.Millisecond
	count := ctx.Whisper_full_n_tokens(n)
	tokens := make([]Token, 0, count)
	for i := 0; i < count; i++ {
		data := ctx.Whisper_full_get_token_data(n, i)
		tokens = append(tokens, Token{
			Id:    int(ctx.Whisper_full_get_token_id(n, i)),
			Text:  ctx.Whisper_full_get_token_text(n, i),
			P:     ctx.Whisper_full_get_token_p(n, i),
			Start: time.Duration(data.T0()) * tick,
			End:   time.Duration(data.T1()) * tick,
		})
	}
	return tokens
}

View File

@ -0,0 +1,285 @@
package whisper_test
import (
"fmt"
"io"
"math"
"os"
"runtime"
"testing"
"time"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/audio"
wav "github.com/go-audio/wav"
)
// processAndExtractSegmentsSequentially runs Process on samples and then
// drains every segment via NextSegment until io.EOF.
func processAndExtractSegmentsSequentially(ctx whisper.Context, samples []float32) ([]whisper.Segment, error) {
	if err := ctx.Process(samples, nil, nil, nil); err != nil {
		return nil, err
	}
	var segments []whisper.Segment
	for {
		seg, err := ctx.NextSegment()
		if err != nil {
			if err == io.EOF {
				return segments, nil
			}
			return nil, err
		}
		segments = append(segments, seg)
	}
}
// processAndExtractSegmentsWithCallback runs Process with a segment
// callback and collects each delivered segment as it arrives.
func processAndExtractSegmentsWithCallback(ctx whisper.Context, samples []float32) ([]whisper.Segment, error) {
	segments := make([]whisper.Segment, 0)
	collect := func(seg whisper.Segment) {
		segments = append(segments, seg)
	}
	if err := ctx.Process(samples, nil, collect, nil); err != nil {
		return nil, err
	}
	return segments, nil
}
// benchProcessVariants runs the common benchmark matrix across context kinds,
// thread sets, and callback modes, for given samples. If singleIteration is true
// it runs only one iteration regardless of b.N. If printTimings is true,
// model timings and custom ms_process metric are reported for NoCallback runs.
func benchProcessVariants(
	b *testing.B,
	samples []float32,
	singleIteration bool,
	printTimings bool,
	useGPU bool,
) {
	threadSets := []uint{1, 2, 4, uint(runtime.NumCPU())}
	device := "cpu"
	if useGPU {
		device = "gpu"
	}
	// Initialize model per device mode
	mp := whisper.NewModelContextParams()
	mp.SetUseGPU(useGPU)
	model, err := whisper.NewModelContextWithParams(ModelPath, mp)
	if err != nil {
		b.Fatalf("load model (%s): %v", device, err)
	}
	defer func() { _ = model.Close() }()
	// Context kinds: stateless and stateful
	ctxKinds := []struct {
		name string
		new  func() (whisper.Context, error)
	}{
		{
			name: "stateless",
			new: func() (whisper.Context, error) {
				params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {})
				if err != nil {
					return nil, err
				}
				return whisper.NewStatelessContext(model, params)
			},
		},
		{
			name: "stateful",
			new: func() (whisper.Context, error) {
				params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
				if err != nil {
					return nil, err
				}
				return whisper.NewStatefulContext(model, params)
			},
		},
	}
	for _, kind := range ctxKinds {
		b.Run(device+"/"+kind.name, func(b *testing.B) {
			for _, threads := range threadSets {
				b.Run(fmt.Sprintf("threads=%d/NoCallback", threads), func(b *testing.B) {
					b.ReportAllocs()
					// Bytes per op = input size in bytes (float32 = 4 bytes/sample)
					b.SetBytes(int64(len(samples) * 4))
					ctx, err := kind.new()
					if err != nil {
						b.Fatalf("new %s context: %v", kind.name, err)
					}
					// defer inside the b.Run closure fires at subtest end
					defer func() { _ = ctx.Close() }()
					ctx.SetThreads(threads)
					iters := b.N
					if singleIteration {
						iters = 1
					}
					b.ResetTimer()
					for i := 0; i < iters; i++ {
						model.ResetTimings()
						start := time.Now()
						segments, err := processAndExtractSegmentsSequentially(ctx, samples)
						if err != nil {
							b.Fatalf("process and extract segments sequentially: %v", err)
						}
						// NOTE(review): Logf inside the timed loop adds I/O to the
						// measured time when -v is used — consider moving after StopTimer.
						b.Logf("segments: %+v", segments)
						elapsed := time.Since(start)
						if printTimings {
							model.PrintTimings()
						}
						// NOTE(review): ReportMetric is called once per iteration;
						// only the last reported value is kept by the harness.
						b.ReportMetric(float64(elapsed.Milliseconds()), "ms_process")
					}
				})
				b.Run(fmt.Sprintf("threads=%d/WithSegmentCallback", threads), func(b *testing.B) {
					b.ReportAllocs()
					b.SetBytes(int64(len(samples) * 4))
					ctx, err := kind.new()
					if err != nil {
						b.Fatalf("new %s context: %v", kind.name, err)
					}
					defer func() { _ = ctx.Close() }()
					ctx.SetThreads(threads)
					iters := b.N
					if singleIteration {
						iters = 1
					}
					b.ResetTimer()
					for i := 0; i < iters; i++ {
						start := time.Now()
						model.ResetTimings()
						// Passing a segment callback forces single-segment mode and exercises token extraction
						segments, err := processAndExtractSegmentsWithCallback(ctx, samples)
						if err != nil {
							b.Fatalf("process with callback: %v", err)
						}
						b.Logf("segments: %+v", segments)
						elapsed := time.Since(start)
						if printTimings {
							model.PrintTimings()
						}
						b.ReportMetric(float64(elapsed.Milliseconds()), "ms_process")
					}
				})
			}
		})
	}
}
// BenchmarkContextProcessCPU runs the high-level Context.Process on the
// CPU across different thread counts, with and without segment callbacks.
func BenchmarkContextProcessCPU(b *testing.B) {
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		b.Skipf("model not found: %s", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		b.Skipf("sample not found: %s", SamplePath)
	}
	// Decode the sample once and share it across all variants.
	samples := helperLoadSample(b, SamplePath)
	benchProcessVariants(b, samples, false /* singleIteration */, true /* printTimings */, false /* useGPU */)
}
// BenchmarkContextProcessBig runs one single iteration over a big input
// (the short sample concatenated 10x) to simulate long audio processing.
// This is complementary to BenchmarkContextProcess which runs many iterations
// over the short sample.
func BenchmarkContextProcessBigCPU(b *testing.B) {
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		b.Skipf("model not found: %s", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		b.Skipf("sample not found: %s", SamplePath)
	}
	// Load audio once (reuse helper with meta)
	data, sampleRate, numChans := helperLoadSampleWithMeta(b, SamplePath)
	// Build big dataset: input concatenated 10x
	bigData := make([]float32, len(data)*10)
	for i := 0; i < 10; i++ {
		copy(bigData[i*len(data):(i+1)*len(data)], data)
	}
	// Write the big dataset to a wav file for inspection.
	// NOTE(review): assumes ../../samples exists; os.Create fails the
	// benchmark otherwise.
	outPath := "../../samples/benchmark_out.wav"
	fout, err := os.Create(outPath)
	if err != nil {
		b.Fatalf("create output wav: %v", err)
	}
	enc := wav.NewEncoder(fout, sampleRate, 16, numChans, 1)
	intBuf := &audio.IntBuffer{
		Format:         &audio.Format{NumChannels: numChans, SampleRate: sampleRate},
		SourceBitDepth: 16,
		Data:           make([]int, len(bigData)),
	}
	// Convert float32 samples back to clamped 16-bit integer PCM
	for i, s := range bigData {
		v := int(math.Round(float64(s) * 32767.0))
		if v > 32767 {
			v = 32767
		} else if v < -32768 {
			v = -32768
		}
		intBuf.Data[i] = v
	}
	if err := enc.Write(intBuf); err != nil {
		_ = fout.Close()
		b.Fatalf("encode wav: %v", err)
	}
	// The encoder must be closed before the file to flush the wav header
	if err := enc.Close(); err != nil {
		_ = fout.Close()
		b.Fatalf("close encoder: %v", err)
	}
	// Close error deliberately ignored: the file is for inspection only
	_ = fout.Close()
	benchProcessVariants(b, bigData, true, true, false)
}
// BenchmarkContextProcessGPU mirrors the CPU benchmark but enables GPU
// at the model level via model params.
func BenchmarkContextProcessGPU(b *testing.B) {
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		b.Skipf("model not found: %s", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		b.Skipf("sample not found: %s", SamplePath)
	}
	samples := helperLoadSample(b, SamplePath)
	benchProcessVariants(b, samples, false /* singleIteration */, true /* printTimings */, true /* useGPU */)
}
// BenchmarkContextProcessBigGPU runs a single iteration over the sample
// concatenated 10x, with GPU enabled at the model level.
func BenchmarkContextProcessBigGPU(b *testing.B) {
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		b.Skipf("model not found: %s", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		b.Skipf("sample not found: %s", SamplePath)
	}
	data, _, _ := helperLoadSampleWithMeta(b, SamplePath)
	// Concatenate the sample 10x to simulate long audio
	big := make([]float32, 0, len(data)*10)
	for i := 0; i < 10; i++ {
		big = append(big, data...)
	}
	benchProcessVariants(b, big, true /* singleIteration */, true /* printTimings */, true /* useGPU */)
}

View File

@ -1,124 +1,526 @@
package whisper_test
import (
"io"
"os"
"testing"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
assert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetLanguage(t *testing.T) {
assert := assert.New(t)
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
cases := []struct {
name string
new func(t *testing.T) (whisper.Context, func())
}{
{name: "stateless", new: helperNewStatelessContext},
{name: "stateful", new: helperNewStatefulContext},
}
context, err := model.NewContext()
assert.NoError(err)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cleanup := tc.new(t)
defer cleanup()
// This returns an error since
// the model 'models/ggml-small.en.bin'
// that is loaded is not multilingual
err = context.SetLanguage("en")
assert.Error(err)
// This returns an error since the small.en model is not multilingual
err := ctx.SetLanguage("en")
assert.Error(err)
})
}
}
func TestContextModelIsMultilingual(t *testing.T) {
assert := assert.New(t)
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
cases := []struct {
name string
new func(t *testing.T) (whisper.Context, func())
}{
{name: "stateless", new: helperNewStatelessContext},
{name: "stateful", new: helperNewStatefulContext},
}
context, err := model.NewContext()
assert.NoError(err)
isMultilingual := context.IsMultilingual()
// This returns false since
// the model 'models/ggml-small.en.bin'
// that is loaded is not multilingual
assert.False(isMultilingual)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cleanup := tc.new(t)
defer cleanup()
assert.False(ctx.IsMultilingual())
})
}
}
func TestLanguage(t *testing.T) {
assert := assert.New(t)
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
cases := []struct {
name string
new func(t *testing.T) (whisper.Context, func())
}{
{name: "stateless", new: helperNewStatelessContext},
{name: "stateful", new: helperNewStatefulContext},
}
context, err := model.NewContext()
assert.NoError(err)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cleanup := tc.new(t)
defer cleanup()
expectedLanguage := "en"
actualLanguage := ctx.Language()
assert.Equal(expectedLanguage, actualLanguage)
})
}
}
// This always returns en since
// the model 'models/ggml-small.en.bin'
// that is loaded is not multilingual
expectedLanguage := "en"
actualLanguage := context.Language()
assert.Equal(expectedLanguage, actualLanguage)
// Generic behavior: Language() and DetectedLanguage() match for both context types
func TestContext_Generic_LanguageAndDetectedLanguage(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
t.Skip("Skipping test, model not found:", ModelPath)
}
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
t.Skip("Skipping test, sample not found:", SamplePath)
}
data := helperLoadSample(t, SamplePath)
cases := []struct {
name string
new func(t *testing.T) (whisper.Context, func())
}{
{name: "stateless", new: helperNewStatelessContext},
{name: "stateful", new: helperNewStatefulContext},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cleanup := tc.new(t)
defer cleanup()
langBefore := ctx.Language()
assert.NoError(ctx.Process(data, nil, nil, nil))
detected := ctx.DetectedLanguage()
assert.Equal(langBefore, detected)
})
}
}
// TestProcess verifies that Process succeeds on the sample audio for both
// stateless and stateful contexts.
//
// Fixes diff leftovers: `data` was declared twice (compile error) and the
// pre-refactor WAV/model/context setup remained alongside the table-driven
// version.
func TestProcess(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	// Decode the WAV file - load the full buffer as mono float32 PCM.
	data := helperLoadSample(t, SamplePath)
	cases := []struct {
		name string
		new func(t *testing.T) (whisper.Context, func())
	}{
		{name: "stateless", new: helperNewStatelessContext},
		{name: "stateful", new: helperNewStatefulContext},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cleanup := tc.new(t)
			defer cleanup()
			err := ctx.Process(data, nil, nil, nil)
			assert.NoError(err)
		})
	}
}
// TestDetectedLanguage verifies that after processing the sample audio the
// detected language is "en" for both stateless and stateful contexts
// (the test model is English-only).
//
// Fixes diff leftovers: `data` was declared twice (compile error) and the
// pre-refactor WAV/model/context setup remained alongside the table-driven
// version.
func TestDetectedLanguage(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	// Decode the WAV file - load the full buffer as mono float32 PCM.
	data := helperLoadSample(t, SamplePath)
	cases := []struct {
		name string
		new func(t *testing.T) (whisper.Context, func())
	}{
		{name: "stateless", new: helperNewStatelessContext},
		{name: "stateful", new: helperNewStatefulContext},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cleanup := tc.new(t)
			defer cleanup()
			err := ctx.Process(data, nil, nil, nil)
			assert.NoError(err)
			expectedLanguage := "en"
			actualLanguage := ctx.DetectedLanguage()
			assert.Equal(expectedLanguage, actualLanguage)
		})
	}
}
// TestContext_ConcurrentProcessing tests that multiple contexts can process concurrently
// without interfering with each other (validates the whisper_state isolation fix)
//
// NOTE(review): the subtests below run sequentially, one context at a time —
// there is no goroutine-level concurrency in this function. Confirm whether
// parallel execution (t.Parallel or explicit goroutines) was intended.
func TestContext_ConcurrentProcessing(t *testing.T) {
	assert := assert.New(t)
	// Skip (rather than fail) when the fixture files are absent.
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	// Mono float32 PCM decoded from the sample WAV.
	data := helperLoadSample(t, SamplePath)
	cases := []struct {
		name string
		new func(t *testing.T) (whisper.Context, func())
	}{
		{name: "stateless", new: helperNewStatelessContext},
		{name: "stateful", new: helperNewStatefulContext},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cleanup := tc.new(t)
			defer cleanup()
			// Each context processes the full sample and must yield at
			// least one non-empty segment.
			err := ctx.Process(data, nil, nil, nil)
			assert.NoError(err)
			seg, err := ctx.NextSegment()
			assert.NoError(err)
			assert.NotEmpty(seg.Text)
		})
	}
}
// TestContext_Close tests that Context.Close() properly frees resources
// and allows context to be used even after it has been closed
func TestContext_Close(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	contextKinds := []struct {
		name string
		new func(t *testing.T) (whisper.Context, func())
	}{
		{name: "stateless", new: helperNewStatelessContext},
		{name: "stateful", new: helperNewStatefulContext},
	}
	for _, kind := range contextKinds {
		t.Run(kind.name, func(t *testing.T) {
			ctx, done := kind.new(t)
			defer done()
			// Closing once must succeed.
			require.NoError(t, ctx.Close())
			// A closed context rejects processing with ErrModelClosed.
			err := ctx.Process([]float32{0.1, 0.2, 0.3}, nil, nil, nil)
			require.ErrorIs(t, err, whisper.ErrModelClosed)
			// TODO: remove this logic after deprecating the ErrInternalAppError
			require.ErrorIs(t, err, whisper.ErrInternalAppError)
			// Queries degrade gracefully instead of panicking.
			require.Empty(t, ctx.DetectedLanguage())
			_, err = ctx.NextSegment()
			assert.ErrorIs(err, whisper.ErrModelClosed)
			// TODO: remove this logic after deprecating the ErrInternalAppError
			assert.ErrorIs(err, whisper.ErrInternalAppError)
			// Multiple closes should be safe
			require.NoError(t, ctx.Close())
		})
	}
}
// Closing a context after its parent model has already been closed must not
// fail or panic, for both stateless and stateful contexts.
func Test_Close_Context_of_Closed_Model(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	kinds := []struct {
		name   string
		newCtx func(m *whisper.ModelContext, p *whisper.Parameters) (io.Closer, error)
	}{
		{
			name: "stateless",
			newCtx: func(m *whisper.ModelContext, p *whisper.Parameters) (io.Closer, error) {
				return whisper.NewStatelessContext(m, p)
			},
		},
		{
			name: "stateful",
			newCtx: func(m *whisper.ModelContext, p *whisper.Parameters) (io.Closer, error) {
				return whisper.NewStatefulContext(m, p)
			},
		},
	}
	for _, kind := range kinds {
		t.Run(kind.name, func(t *testing.T) {
			model, err := whisper.NewModelContext(ModelPath)
			assert.NoError(err)
			defer func() { _ = model.Close() }()
			params := helperNewParams(t, model, nil)
			ctx, err := kind.newCtx(model, params)
			assert.NoError(err)
			// Close the model first, then the context.
			require.NoError(t, model.Close())
			require.NoError(t, ctx.Close())
		})
	}
}
// Setting every diarization and VAD parameter must not panic and must leave
// Process able to run cleanly on the sample audio.
func TestContext_VAD_And_Diarization_Params_DoNotPanic(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	data := helperLoadSample(t, SamplePath)
	model, err := whisper.NewModelContext(ModelPath)
	assert.NoError(err)
	defer func() { _ = model.Close() }()
	params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
	assert.NoError(err)
	assert.NotNil(params)
	ctx, err := whisper.NewStatefulContext(model, params)
	assert.NoError(err)
	defer func() { _ = ctx.Close() }()
	// Flip every diarization/VAD knob before processing.
	cfg := ctx.Params()
	cfg.SetDiarize(true)
	cfg.SetVAD(true)
	cfg.SetVADThreshold(0.5)
	cfg.SetVADMinSpeechMs(200)
	cfg.SetVADMinSilenceMs(100)
	cfg.SetVADMaxSpeechSec(10)
	cfg.SetVADSpeechPadMs(30)
	cfg.SetVADSamplesOverlap(0.02)
	assert.NoError(ctx.Process(data, nil, nil, nil))
}
// The Segment.SpeakerTurnNext field must be present and readable for both
// context kinds after processing the sample.
func TestContext_SpeakerTurnNext_Field_Present(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	data := helperLoadSample(t, SamplePath)
	contextKinds := []struct {
		name string
		new func(t *testing.T) (whisper.Context, func())
	}{
		{name: "stateless", new: helperNewStatelessContext},
		{name: "stateful", new: helperNewStatefulContext},
	}
	for _, kind := range contextKinds {
		t.Run(kind.name, func(t *testing.T) {
			ctx, done := kind.new(t)
			defer done()
			err := ctx.Process(data, nil, nil, nil)
			assert.NoError(err)
			seg, err := ctx.NextSegment()
			assert.NoError(err)
			// Log the value for inspection; the read itself is the check.
			t.Logf("SpeakerTurnNext: %v", seg.SpeakerTurnNext)
			_ = seg.SpeakerTurnNext
		})
	}
}
// Ensure Process produces at least one segment for both stateless and stateful contexts
func TestContext_Process_ProducesSegments_BothKinds(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	data := helperLoadSample(t, SamplePath)
	kinds := []struct {
		new func(t *testing.T) (whisper.Context, func())
		msg string
	}{
		{new: helperNewStatelessContext, msg: "stateless should produce at least one segment"},
		{new: helperNewStatefulContext, msg: "stateful should produce at least one segment"},
	}
	for _, kind := range kinds {
		ctx, done := kind.new(t)
		defer done()
		require.NoError(t, ctx.Process(data, nil, nil, nil))
		// Drain every segment, counting them.
		count := 0
		for {
			_, err := ctx.NextSegment()
			if err == io.EOF {
				break
			}
			require.NoError(t, err)
			count++
		}
		assert.Greater(count, 0, kind.msg)
	}
}
// With temperature=0 (greedy), stateless and stateful should produce identical segments
func TestContext_Process_SameResults_TemperatureZero(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	data := helperLoadSample(t, SamplePath)
	// Use a single model to avoid environment differences
	model, err := whisper.NewModelContext(ModelPath)
	require.NoError(t, err)
	defer func() { _ = model.Close() }()
	// Independent params with temperature=0 for determinism
	p := helperNewParams(t, model, func(p *whisper.Parameters) {
		p.SetTemperature(0)
		p.SetThreads(1)
	})
	stateless, err := whisper.NewStatelessContext(model, p)
	require.NoError(t, err)
	defer func() { _ = stateless.Close() }()
	stateful, err := whisper.NewStatefulContext(model, p)
	require.NoError(t, err)
	defer func() { _ = stateful.Close() }()
	require.NoError(t, stateless.Process(data, nil, nil, nil))
	require.NoError(t, stateful.Process(data, nil, nil, nil))
	// drain reads all remaining segment texts from a processed context.
	drain := func(c interface {
		NextSegment() (whisper.Segment, error)
	}) []string {
		var texts []string
		for {
			seg, err := c.NextSegment()
			if err == io.EOF {
				return texts
			}
			require.NoError(t, err)
			texts = append(texts, seg.Text)
		}
	}
	segsStateless := drain(stateless)
	segsStateful := drain(stateful)
	// Both should have at least one segment and be identical
	require.Greater(t, len(segsStateless), 0)
	require.Greater(t, len(segsStateful), 0)
	assert.Equal(len(segsStateful), len(segsStateless))
	for i := range segsStateless {
		assert.Equal(segsStateless[i], segsStateful[i], "segment %d text differs", i)
	}
}
// Model.GetTimings: stateless processing updates model timings (non-zero),
// stateful processing does not (zero timings)
func TestModel_GetTimings_Stateless_NonZero_Stateful_Zero(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	data := helperLoadSample(t, SamplePath)
	model, err := whisper.NewModelContext(ModelPath)
	require.NoError(t, err)
	defer func() { _ = model.Close() }()
	// Stateless should produce non-zero timings
	t.Run("stateless", func(t *testing.T) {
		// Clear counters so this subtest measures only its own Process call.
		model.ResetTimings()
		params := helperNewParams(t, model, nil)
		ctx, err := whisper.NewStatelessContext(model, params)
		require.NoError(t, err)
		defer func() { _ = ctx.Close() }()
		require.NoError(t, ctx.Process(data, nil, nil, nil))
		timings, ok := model.GetTimings()
		require.True(t, ok, "expected timings to be available after stateless processing")
		// Any single non-zero counter proves model-level timings moved.
		nonZero := timings.SampleMS > 0 || timings.EncodeMS > 0 || timings.DecodeMS > 0 || timings.BatchdMS > 0 || timings.PromptMS > 0
		assert.True(nonZero, "expected at least one non-zero timing after stateless processing: %#v", timings)
	})
	// Stateful should keep model-level timings at zero
	t.Run("stateful", func(t *testing.T) {
		model.ResetTimings()
		params := helperNewParams(t, model, nil)
		ctx, err := whisper.NewStatefulContext(model, params)
		require.NoError(t, err)
		defer func() { _ = ctx.Close() }()
		require.NoError(t, ctx.Process(data, nil, nil, nil))
		timings, ok := model.GetTimings()
		// Expect timings present but all zero; if not present at all, treat as zero-equivalent
		if ok {
			assert.Equal(float32(0), timings.SampleMS)
			assert.Equal(float32(0), timings.EncodeMS)
			assert.Equal(float32(0), timings.DecodeMS)
			assert.Equal(float32(0), timings.BatchdMS)
			assert.Equal(float32(0), timings.PromptMS)
		} else {
			t.Log("timings not available for stateful processing; treating as zero")
		}
	})
}

View File

@ -1,6 +1,7 @@
package whisper
import (
"fmt"
"io"
"time"
)
@ -20,15 +21,21 @@ type ProgressCallback func(int)
// continue processing. It is called during the Process function
type EncoderBeginCallback func() bool
type ParamsConfigure func(*Parameters)
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
// Deprecated: Use NewModel implementation struct instead of relying on this interface
type Model interface {
io.Closer
// Return a new speech-to-text context.
// It may return an error is the model is not loaded or closed
// Deprecated: Use NewContext implementation struct instead of relying on this interface
NewContext() (Context, error)
// Return true if the model is multilingual.
// It returns false if the model is not loaded or closed
IsMultilingual() bool
// Return all languages supported.
@ -36,38 +43,91 @@ type Model interface {
}
// Context is the speech recognition context.
// Deprecated: Use NewContext implementation struct instead of relying on this interface
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
DetectedLanguage() string // Get detected language
io.Closer
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
SetAudioCtx(uint) // Set audio encoder context
SetMaxContext(n int) // Set maximum number of text context tokens to store
SetBeamSize(n int) // Set Beam Size
SetEntropyThold(t float32) // Set Entropy threshold
SetInitialPrompt(prompt string) // Set initial prompt
SetTemperature(t float32) // Set temperature
SetTemperatureFallback(t float32) // Set temperature incrementation
// Deprecated: Use Params().SetLanguage() instead
SetLanguage(string) error
SetVAD(v bool)
SetVADModelPath(path string)
SetVADThreshold(t float32)
SetVADMinSpeechMs(ms int)
SetVADMinSilenceMs(ms int)
SetVADMaxSpeechSec(s float32)
SetVADSpeechPadMs(ms int)
SetVADSamplesOverlap(sec float32)
// Deprecated: Use Params().SetTranslate() instead
SetTranslate(bool)
// Deprecated: Use Params().SetSplitOnWord() instead
SetSplitOnWord(bool)
// Deprecated: Use Params().SetThreads() instead
SetThreads(uint)
// Deprecated: Use Params().SetOffset() instead
SetOffset(time.Duration)
// Deprecated: Use Params().SetDuration() instead
SetDuration(time.Duration)
// Deprecated: Use Params().SetTokenThreshold() instead
SetTokenThreshold(float32)
// Deprecated: Use Params().SetTokenSumThreshold() instead
SetTokenSumThreshold(float32)
// Deprecated: Use Params().SetMaxSegmentLength() instead
SetMaxSegmentLength(uint)
// Deprecated: Use Params().SetTokenTimestamps() instead
SetTokenTimestamps(bool)
// Deprecated: Use Params().SetMaxTokensPerSegment() instead
SetMaxTokensPerSegment(uint)
// Deprecated: Use Params().SetAudioCtx() instead
SetAudioCtx(uint)
// Deprecated: Use Params().SetMaxContext() instead
SetMaxContext(int)
// Deprecated: Use Params().SetBeamSize() instead
SetBeamSize(int)
// Deprecated: Use Params().SetEntropyThold() instead
SetEntropyThold(float32)
// Deprecated: Use Params().SetTemperature() instead
SetTemperature(float32)
// Deprecated: Use Params().SetTemperatureFallback() instead
SetTemperatureFallback(float32)
// Deprecated: Use Params().SetInitialPrompt() instead
SetInitialPrompt(string)
// Get language of the context parameters
// Deprecated: Use Params().Language() instead
Language() string
// Deprecated: Use Model().IsMultilingual() instead
IsMultilingual() bool
// Get detected language
DetectedLanguage() string
// Voice Activity Detection (VAD) methods
// Deprecated: Use Params().SetVAD() instead
SetVAD(bool)
// Deprecated: Use Params().SetVADModelPath() instead
SetVADModelPath(string)
// Deprecated: Use Params().SetVADThreshold() instead
SetVADThreshold(float32)
// Deprecated: Use Params().SetVADMinSpeechMs() instead
SetVADMinSpeechMs(int)
// Deprecated: Use Params().SetVADMinSilenceMs() instead
SetVADMinSilenceMs(int)
// Deprecated: Use Params().SetVADMaxSpeechSec() instead
SetVADMaxSpeechSec(float32)
// Deprecated: Use Params().SetVADSpeechPadMs() instead
SetVADSpeechPadMs(int)
// Deprecated: Use Params().SetVADSamplesOverlap() instead
SetVADSamplesOverlap(float32)
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
@ -78,19 +138,39 @@ type Context interface {
// is reached, when io.EOF is returned.
NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
// Deprecated: Use Model().TokenIdentifier().IsBEG() instead
IsBEG(Token) bool
// Timings
// Deprecated: Use Model().TokenIdentifier().IsSOT() instead
IsSOT(Token) bool
// Deprecated: Use Model().TokenIdentifier().IsEOT() instead
IsEOT(Token) bool
// Deprecated: Use Model().TokenIdentifier().IsPREV() instead
IsPREV(Token) bool
// Deprecated: Use Model().TokenIdentifier().IsSOLM() instead
IsSOLM(Token) bool
// Deprecated: Use Model().TokenIdentifier().IsNOT() instead
IsNOT(Token) bool
// Deprecated: Use Model().TokenIdentifier().IsLANG() instead
IsLANG(Token, string) bool
// Deprecated: Use Model().TokenIdentifier().IsText() instead
IsText(Token) bool
// Deprecated: Use Model().PrintTimings() instead
// these are model-level performance metrics
PrintTimings()
// Deprecated: Use Model().ResetTimings() instead
// these are model-level performance metrics
ResetTimings()
// SystemInfo returns the system information
SystemInfo() string
}
@ -107,12 +187,29 @@ type Segment struct {
// The tokens of the segment.
Tokens []Token
// True if the next segment is predicted as a speaker turn (tinydiarize)
// It works only with the diarization supporting models (like small.en-tdrz.bin) with the diarization enabled
// using Parameters.SetDiarize(true)
SpeakerTurnNext bool
}
// String renders the segment with millisecond-truncated timestamps.
// format: [00:01:39.000 --> 00:01:50.000] And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
func (s Segment) String() string {
	return fmt.Sprintf("[%s --> %s] %s", s.Start.Truncate(time.Millisecond), s.End.Truncate(time.Millisecond), s.Text)
}
// Token is a text or special token
//
// Fixes diff leftovers: the Id/Text/P fields were declared twice (old and
// new versions interleaved), which does not compile.
type Token struct {
	// ID of the token
	Id int
	// Text of the token
	Text string
	// Probability of the token
	P float32
	// Timestamp of the token
	Start, End time.Duration
}

View File

@ -0,0 +1,9 @@
package whisper
import low "github.com/ggerganov/whisper.cpp/bindings/go"
// DisableLogs disables all C-side logging from whisper.cpp and ggml.
// Call once early in your program before creating models/contexts.
// NOTE(review): this delegates to the low-level binding and appears to act
// process-wide; confirm in the binding whether a re-enable exists.
func DisableLogs() {
	low.DisableLogs()
}

View File

@ -3,99 +3,177 @@ package whisper
import (
"fmt"
"os"
"runtime"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
low "github.com/ggerganov/whisper.cpp/bindings/go"
)
///////////////////////////////////////////////////////////////////////////////
// TYPES
type model struct {
path string
ctx *whisper.Context
type ModelContext struct {
path string
ca *ctxAccessor
tokId *tokenIdentifier
}
// Make sure model adheres to the interface
var _ Model = (*model)(nil)
var _ Model = (*ModelContext)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
// Timings is a compact, high-level timing snapshot in milliseconds
// (field meanings follow whisper.cpp's timing counters as mapped by
// ModelContext.GetTimings).
type Timings struct {
	SampleMS float32 // token sampling time
	EncodeMS float32 // encoder time
	DecodeMS float32 // decoder time
	BatchdMS float32 // batched decoding time
	PromptMS float32 // prompt processing time
}
// New loads a whisper model from path behind the legacy Model interface.
// It simply delegates to NewModelContext.
//
// Deprecated: Use NewModelContext instead
//
// Fixes diff leftover: the removed `model := new(model)` line referenced the
// deleted `model` type and was unused (both compile errors).
func New(path string) (Model, error) {
	return NewModelContext(path)
}
// NewModelContext loads the model at path using the default model
// initialization parameters (see NewModelContextParams).
func NewModelContext(
	path string,
) (*ModelContext, error) {
	defaults := NewModelContextParams()
	return NewModelContextWithParams(path, defaults)
}
// NewModelContextWithParams creates a new model context with custom initialization params.
// It verifies the model file exists, initializes the low-level whisper
// context, and wires up the guarded accessor and token identifier.
//
// Fixes diff leftovers: the removed init chain (whisper.Whisper_init and
// assignments to the deleted model.ctx field) remained interleaved with the
// new low.Whisper_init_with_params path.
func NewModelContextWithParams(
	path string,
	params ModelContextParams,
) (*ModelContext, error) {
	model := new(ModelContext)
	// Fail fast with a filesystem error when the model file is missing.
	if _, err := os.Stat(path); err != nil {
		return nil, err
	}
	ctx := low.Whisper_init_with_params(path, params.toLow())
	if ctx == nil {
		return nil, ErrUnableToLoadModel
	}
	model.ca = newCtxAccessor(ctx)
	model.tokId = newTokenIdentifier(model.ca)
	model.path = path
	return model, nil
}
func (model *model) Close() error {
if model.ctx != nil {
model.ctx.Whisper_free()
}
// Release resources
model.ctx = nil
// Return success
return nil
func (model *ModelContext) Close() error {
return model.ca.close()
}
///////////////////////////////////////////////////////////////////////////////
// STRINGIFY
// ctxAccessor exposes the guarded low-level context for package internals.
func (model *ModelContext) ctxAccessor() *ctxAccessor {
	return model.ca
}
func (model *model) String() string {
func (model *ModelContext) String() string {
str := "<whisper.model"
if model.ctx != nil {
if model.ca != nil {
str += fmt.Sprintf(" model=%q", model.path)
}
return str + ">"
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Return true if model is multilingual (language and translation options are supported)
func (model *model) IsMultilingual() bool {
return model.ctx.Whisper_is_multilingual() != 0
func (model *ModelContext) IsMultilingual() bool {
ctx, err := model.ca.context()
if err != nil {
return false
}
return ctx.Whisper_is_multilingual() != 0
}
// Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id())
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
str := whisper.Whisper_lang_str(i)
if model.ctx.Whisper_lang_id(str) >= 0 {
func (model *ModelContext) Languages() []string {
ctx, err := model.ca.context()
if err != nil {
return nil
}
result := make([]string, 0, low.Whisper_lang_max_id())
for i := 0; i < low.Whisper_lang_max_id(); i++ {
str := low.Whisper_lang_str(i)
if ctx.Whisper_lang_id(str) >= 0 {
result = append(result, str)
}
}
return result
}
func (model *model) NewContext() (Context, error) {
if model.ctx == nil {
return nil, ErrInternalAppError
// NewContext creates a new speech-to-text context.
// Each context is backed by an isolated whisper_state for safe concurrent processing.
func (model *ModelContext) NewContext() (Context, error) {
// Create new context with default params
params, err := NewParameters(model, SAMPLING_GREEDY, nil)
if err != nil {
return nil, err
}
// Create new context
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
params.SetTranslate(false)
params.SetPrintSpecial(false)
params.SetPrintProgress(false)
params.SetPrintRealtime(false)
params.SetPrintTimestamps(false)
params.SetThreads(runtime.NumCPU())
params.SetNoContext(true)
// Return new context
return newContext(model, params)
// Return new context (stateless for backward compatibility with timings)
return NewStatelessContext(
model,
params,
)
}
// PrintTimings prints the model performance timings to stdout.
func (model *ModelContext) PrintTimings() {
	ctx, err := model.ca.context()
	if err != nil {
		// Model already closed: nothing to print.
		return
	}
	ctx.Whisper_print_timings()
}
// ResetTimings resets the model performance timing counters.
func (model *ModelContext) ResetTimings() {
	ctx, err := model.ca.context()
	if err != nil {
		// Model already closed: nothing to reset.
		return
	}
	ctx.Whisper_reset_timings()
}
// GetTimings returns a compact snapshot of model-level processing timings.
//
// Behavior notes:
//   - Stateless contexts (created via ModelContext.NewContext or
//     NewStatelessContext) update model-level timings during Process, so a
//     non-zero snapshot (ok == true) is expected after a stateless call.
//   - Stateful contexts (created via NewStatefulContext) use a per-state
//     backend and do not affect model-level timings: expect zero values, or
//     ok == false depending on the underlying implementation.
//
// Use ResetTimings before measurement to clear previous values.
func (model *ModelContext) GetTimings() (Timings, bool) {
	ctx, err := model.ca.context()
	if err != nil {
		return Timings{}, false
	}
	raw, ok := ctx.Whisper_get_timings_go()
	if !ok {
		return Timings{}, false
	}
	snapshot := Timings{
		SampleMS: raw.SampleMS,
		EncodeMS: raw.EncodeMS,
		DecodeMS: raw.DecodeMS,
		BatchdMS: raw.BatchdMS,
		PromptMS: raw.PromptMS,
	}
	return snapshot, true
}
// tokenIdentifier exposes the model's token classification helper to
// package internals.
func (model *ModelContext) tokenIdentifier() *tokenIdentifier {
	return model.tokId
}

View File

@ -0,0 +1,27 @@
package whisper
import (
low "github.com/ggerganov/whisper.cpp/bindings/go"
)
// ModelContextParams wraps the low-level context initialization parameters
// used when loading a model (see NewModelContextWithParams).
type ModelContextParams struct {
	p low.ContextParams
}
// NewModelContextParams returns the low-level default initialization params.
func NewModelContextParams() ModelContextParams {
	return ModelContextParams{
		p: low.Whisper_context_default_params(),
	}
}
// SetUseGPU enables or disables GPU use when loading the model.
func (p *ModelContextParams) SetUseGPU(v bool) {
	p.p.SetUseGPU(v)
}
// SetGPUDevice selects which GPU device index to use.
func (p *ModelContextParams) SetGPUDevice(n int) {
	p.p.SetGPUDevice(n)
}
// toLow exposes the wrapped low-level params to package internals.
func (p *ModelContextParams) toLow() low.ContextParams {
	return p.p
}

View File

@ -13,7 +13,7 @@ func TestNew(t *testing.T) {
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
defer func() { _ = model.Close() }()
})
@ -42,20 +42,34 @@ func TestNewContext(t *testing.T) {
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
defer func() { _ = model.Close() }()
context, err := model.NewContext()
assert.NoError(err)
assert.NotNil(context)
}
func TestNewContext_ClosedModel(t *testing.T) {
assert := assert.New(t)
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
assert.NoError(model.Close())
context, err := model.NewContext()
assert.ErrorIs(err, whisper.ErrInternalAppError)
assert.ErrorIs(err, whisper.ErrModelClosed)
assert.Nil(context)
}
func TestIsMultilingual(t *testing.T) {
assert := assert.New(t)
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
defer func() { _ = model.Close() }()
isMultilingual := model.IsMultilingual()
@ -71,7 +85,7 @@ func TestLanguages(t *testing.T) {
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
defer func() { _ = model.Close() }()
expectedLanguages := []string{
"en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl",

View File

@ -0,0 +1,121 @@
package whisper
import (
"runtime"
"time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
// Parameters is a high-level wrapper that implements the Parameters interface
// and delegates to the underlying low-level whisper.Params.
type Parameters struct {
	// p is the wrapped low-level parameter struct.
	p *whisper.Params
}
// defaultParamsConfigure applies the package defaults to freshly created
// parameters: no translation, all C-side printing disabled, one thread per
// CPU, and no carried text context (matches legacy NewContext behavior).
func defaultParamsConfigure(params *Parameters) {
	params.SetTranslate(false)
	params.SetPrintSpecial(false)
	params.SetPrintProgress(false)
	params.SetPrintRealtime(false)
	params.SetPrintTimestamps(false)
	// Default behavior backward compatibility
	params.SetThreads(uint(runtime.NumCPU()))
	params.SetNoContext(true)
}
// NewParameters builds decoding parameters from the model's defaults for
// the given sampling strategy. Package defaults are applied first; the
// optional configure callback may then override them. Returns
// ErrModelClosed when the model is no longer available.
func NewParameters(
	model *ModelContext,
	sampling SamplingStrategy,
	configure ParamsConfigure,
) (*Parameters, error) {
	ctx, err := model.ca.context()
	if err != nil {
		return nil, ErrModelClosed
	}
	lowParams := ctx.Whisper_full_default_params(whisper.SamplingStrategy(sampling))
	params := &Parameters{p: &lowParams}
	defaultParamsConfigure(params)
	if configure != nil {
		configure(params)
	}
	return params, nil
}
// The setters below delegate directly to the wrapped low-level
// whisper.Params and configure decoding for contexts created with these
// Parameters. Duration-valued settings are converted to milliseconds.
func (w *Parameters) SetTranslate(v bool) { w.p.SetTranslate(v) }
func (w *Parameters) SetSplitOnWord(v bool) { w.p.SetSplitOnWord(v) }
func (w *Parameters) SetThreads(v uint) { w.p.SetThreads(int(v)) }
func (w *Parameters) SetOffset(d time.Duration) { w.p.SetOffset(int(d.Milliseconds())) }
func (w *Parameters) SetDuration(d time.Duration) { w.p.SetDuration(int(d.Milliseconds())) }
func (w *Parameters) SetTokenThreshold(t float32) { w.p.SetTokenThreshold(t) }
func (w *Parameters) SetTokenSumThreshold(t float32) { w.p.SetTokenSumThreshold(t) }
func (w *Parameters) SetMaxSegmentLength(n uint) { w.p.SetMaxSegmentLength(int(n)) }
func (w *Parameters) SetTokenTimestamps(b bool) { w.p.SetTokenTimestamps(b) }
func (w *Parameters) SetMaxTokensPerSegment(n uint) { w.p.SetMaxTokensPerSegment(int(n)) }
func (w *Parameters) SetAudioCtx(n uint) { w.p.SetAudioCtx(int(n)) }
func (w *Parameters) SetMaxContext(n int) { w.p.SetMaxContext(n) }
func (w *Parameters) SetBeamSize(n int) { w.p.SetBeamSize(n) }
func (w *Parameters) SetEntropyThold(t float32) { w.p.SetEntropyThold(t) }
func (w *Parameters) SetInitialPrompt(prompt string) { w.p.SetInitialPrompt(prompt) }
func (w *Parameters) SetCarryInitialPrompt(v bool) { w.p.SetCarryInitialPrompt(v) }
func (w *Parameters) SetTemperature(t float32) { w.p.SetTemperature(t) }
func (w *Parameters) SetTemperatureFallback(t float32) { w.p.SetTemperatureFallback(t) }
func (w *Parameters) SetNoContext(v bool) { w.p.SetNoContext(v) }
func (w *Parameters) SetPrintSpecial(v bool) { w.p.SetPrintSpecial(v) }
func (w *Parameters) SetPrintProgress(v bool) { w.p.SetPrintProgress(v) }
func (w *Parameters) SetPrintRealtime(v bool) { w.p.SetPrintRealtime(v) }
func (w *Parameters) SetPrintTimestamps(v bool) { w.p.SetPrintTimestamps(v) }
func (w *Parameters) SetDebugMode(v bool) { w.p.SetDebugMode(v) }
// Diarization (tinydiarize)
func (w *Parameters) SetDiarize(v bool) { w.p.SetDiarize(v) }
// Voice Activity Detection (VAD)
func (w *Parameters) SetVAD(v bool) { w.p.SetVAD(v) }
func (w *Parameters) SetVADModelPath(p string) { w.p.SetVADModelPath(p) }
func (w *Parameters) SetVADThreshold(t float32) { w.p.SetVADThreshold(t) }
func (w *Parameters) SetVADMinSpeechMs(ms int) { w.p.SetVADMinSpeechMs(ms) }
func (w *Parameters) SetVADMinSilenceMs(ms int) { w.p.SetVADMinSilenceMs(ms) }
func (w *Parameters) SetVADMaxSpeechSec(s float32) { w.p.SetVADMaxSpeechSec(s) }
func (w *Parameters) SetVADSpeechPadMs(ms int) { w.p.SetVADSpeechPadMs(ms) }
func (w *Parameters) SetVADSamplesOverlap(sec float32) { w.p.SetVADSamplesOverlap(sec) }
// SetLanguage selects the decoding language by name; "auto" enables
// automatic language detection (low-level id -1). It returns
// ErrUnsupportedLanguage when whisper does not recognize the name.
func (w *Parameters) SetLanguage(lang string) error {
	if lang == "auto" {
		return w.p.SetLanguage(-1)
	}
	if id := whisper.Whisper_lang_id_str(lang); id >= 0 {
		return w.p.SetLanguage(id)
	}
	return ErrUnsupportedLanguage
}
// SetSingleSegment forces output into a single segment when enabled.
func (w *Parameters) SetSingleSegment(v bool) {
	w.p.SetSingleSegment(v)
}
// Getter methods for Parameters interface
// Language returns the configured language name, or "auto" when automatic
// language detection is enabled (low-level id -1).
func (w *Parameters) Language() string {
	id := w.p.Language()
	if id == -1 {
		return "auto"
	}
	return whisper.Whisper_lang_str(id)
}
// Threads reports the configured number of decoding threads.
func (w *Parameters) Threads() int {
	return w.p.Threads()
}
// unsafeParams exposes the raw low-level params to package internals.
// NOTE(review): the error is always nil here; the signature presumably
// matches an internal accessor interface — confirm against callers.
func (w *Parameters) unsafeParams() (*whisper.Params, error) {
	return w.p, nil
}

View File

@ -0,0 +1,438 @@
package whisper
import (
"fmt"
"io"
"runtime"
"strings"
"time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
// StatefulContext is a speech-to-text context backed by its own isolated
// whisper_state, allowing several contexts of one model to be used without
// sharing decoder state (see NewStatefulContext).
type StatefulContext struct {
	// n: segment cursor (NOTE(review): usage not visible in this chunk —
	// presumably the next segment index for NextSegment; confirm).
	n int
	// model is the owning model context; never nil after construction.
	model *ModelContext
	// st guards the per-context whisper state allocated at construction.
	st *whisperState
	// params holds the decoding parameters used by Process.
	params *Parameters
}
// NewStatefulContext creates a stateful context for the given model using
// the supplied parameters. Each context allocates its own whisper state.
func NewStatefulContext(model *ModelContext, params *Parameters) (*StatefulContext, error) {
	switch {
	case model == nil:
		return nil, errModelRequired
	case params == nil:
		return nil, errParametersRequired
	}
	ctx, err := model.ctxAccessor().context()
	if err != nil {
		return nil, err
	}
	// allocate isolated state per context
	rawState := ctx.Whisper_init_state()
	if rawState == nil {
		return nil, errUnableToCreateState
	}
	return &StatefulContext{
		model:  model,
		params: params,
		st:     newWhisperState(rawState),
	}, nil
}
// DetectedLanguage returns the language whisper detected for the audio
// most recently processed through this context's state, or "" when the
// model or the state has already been closed.
func (context *StatefulContext) DetectedLanguage() string {
	ctx, ctxErr := context.model.ctxAccessor().context()
	if ctxErr != nil {
		return ""
	}
	st, stErr := context.st.unsafeState()
	if stErr != nil {
		return ""
	}
	id := ctx.Whisper_full_lang_id_from_state(st)
	return whisper.Whisper_lang_str(id)
}
// Close frees the whisper state and marks the context as closed.
// The shared ModelContext is left untouched.
func (context *StatefulContext) Close() error {
	return context.st.close()
}

// Params returns a high-level parameters wrapper
func (context *StatefulContext) Params() *Parameters {
	return context.params
}

// ResetTimings resets the model performance timing counters.
// Deprecated: Use Model.ResetTimings() instead - these are model-level performance metrics.
func (context *StatefulContext) ResetTimings() {
	context.model.ResetTimings()
}

// PrintTimings prints the model performance timings to stdout.
// Deprecated: Use Model.PrintTimings() instead - these are model-level performance metrics.
func (context *StatefulContext) PrintTimings() {
	context.model.PrintTimings()
}

// SystemInfo returns the system information
func (context *StatefulContext) SystemInfo() string {
	return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
		context.params.Threads(),
		runtime.NumCPU(),
		whisper.Whisper_print_system_info(),
	)
}
// Use mel data at offset_ms to try and auto-detect the spoken language.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages for this context's state.
func (context *StatefulContext) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return nil, err
	}
	st, err := context.st.unsafeState()
	if err != nil {
		return nil, err
	}
	probs, err := ctx.Whisper_lang_auto_detect_with_state(st, offset_ms, n_threads)
	if err != nil {
		return nil, err
	}
	return probs, nil
}
// Process new sample data and return any errors.
//
// Process transcribes data (float32 PCM samples) against this context's
// private whisper state, so concurrent Process calls on sibling
// StatefulContexts are safe. callNewSegment (if non-nil) is invoked for
// every freshly decoded segment; callProgress (if non-nil) receives
// progress updates.
func (context *StatefulContext) Process(
	data []float32,
	callEncoderBegin EncoderBeginCallback,
	callNewSegment SegmentCallback,
	callProgress ProgressCallback,
) error {
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return err
	}
	// If the callback is defined then we force on single_segment mode.
	// NOTE(review): this mutates the shared Parameters as a side effect.
	if callNewSegment != nil {
		context.params.SetSingleSegment(true)
	}
	lowLevelParams, err := context.params.unsafeParams()
	if err != nil {
		return err
	}
	st, err := context.st.unsafeState()
	if err != nil {
		return err
	}
	// The inline closures fan new-segment and progress events out to the
	// optional user callbacks.
	if err := ctx.Whisper_full_with_state(st, *lowLevelParams, data, callEncoderBegin,
		func(new int) {
			if callNewSegment != nil {
				// "new" is the number of segments added by this step;
				// replay only those to the callback.
				num_segments := ctx.Whisper_full_n_segments_from_state(st)
				s0 := num_segments - new
				for i := s0; i < num_segments; i++ {
					callNewSegment(toSegmentFromState(ctx, st, i))
				}
			}
		}, func(progress int) {
			if callProgress != nil {
				callProgress(progress)
			}
		}); err != nil {
		return err
	}
	// Return success
	return nil
}
// NextSegment returns the next unread segment produced by the last
// Process call, advancing the internal cursor. It returns io.EOF once
// every segment has been consumed.
func (context *StatefulContext) NextSegment() (Segment, error) {
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return Segment{}, err
	}
	st, err := context.st.unsafeState()
	if err != nil {
		return Segment{}, err
	}
	total := ctx.Whisper_full_n_segments_from_state(st)
	if context.n >= total {
		return Segment{}, io.EOF
	}
	seg := toSegmentFromState(ctx, st, context.n)
	context.n++
	return seg, nil
}
// IsMultilingual reports whether the underlying model supports multiple languages.
func (context *StatefulContext) IsMultilingual() bool {
	return context.model.IsMultilingual()
}

// Token helpers. Errors from the token identifier are deliberately
// discarded in these deprecated shims; they report false on failure.

// Deprecated: Use Model.IsText() instead - token checking is model-specific.
func (context *StatefulContext) IsText(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsText(t)
	return result
}

// Deprecated: Use Model.IsBEG() instead - token checking is model-specific.
func (context *StatefulContext) IsBEG(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsBEG(t)
	return result
}

// Deprecated: Use Model.IsSOT() instead - token checking is model-specific.
func (context *StatefulContext) IsSOT(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsSOT(t)
	return result
}

// Deprecated: Use Model.IsEOT() instead - token checking is model-specific.
func (context *StatefulContext) IsEOT(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsEOT(t)
	return result
}

// Deprecated: Use Model.IsPREV() instead - token checking is model-specific.
func (context *StatefulContext) IsPREV(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsPREV(t)
	return result
}

// Deprecated: Use Model.IsSOLM() instead - token checking is model-specific.
func (context *StatefulContext) IsSOLM(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsSOLM(t)
	return result
}

// Deprecated: Use Model.IsNOT() instead - token checking is model-specific.
func (context *StatefulContext) IsNOT(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsNOT(t)
	return result
}

// SetLanguage sets the transcription language; it fails when the model
// is closed or is not multilingual.
func (context *StatefulContext) SetLanguage(lang string) error {
	if context.model.ctxAccessor().isClosed() {
		// TODO: remove this logic after deprecating the ErrInternalAppError
		return ErrModelClosed
	}
	if !context.model.IsMultilingual() {
		return ErrModelNotMultilingual
	}
	return context.params.SetLanguage(lang)
}

// Deprecated: Use Model.IsLANG() instead - token checking is model-specific.
func (context *StatefulContext) IsLANG(t Token, lang string) bool {
	result, _ := context.model.tokenIdentifier().IsLANG(t, lang)
	return result
}
// State-backed helper functions

// toSegmentFromState converts segment n held in st into the high-level
// Segment type. The low-level t0/t1 timestamps are multiplied by
// 10*time.Millisecond, i.e. they are treated as 10 ms ticks.
func toSegmentFromState(ctx *whisper.Context, st *whisper.State, n int) Segment {
	return Segment{
		Num: n,
		Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text_from_state(st, n)),
		Start: time.Duration(ctx.Whisper_full_get_segment_t0_from_state(st, n)) * time.Millisecond * 10,
		End: time.Duration(ctx.Whisper_full_get_segment_t1_from_state(st, n)) * time.Millisecond * 10,
		Tokens: toTokensFromState(ctx, st, n),
		SpeakerTurnNext: ctx.Whisper_full_get_segment_speaker_turn_next_from_state(st, n),
	}
}

// toTokensFromState collects every token of segment n with its text,
// probability and 10 ms-tick timestamps.
func toTokensFromState(ctx *whisper.Context, st *whisper.State, n int) []Token {
	result := make([]Token, ctx.Whisper_full_n_tokens_from_state(st, n))
	for i := 0; i < len(result); i++ {
		data := ctx.Whisper_full_get_token_data_from_state(st, n, i)
		result[i] = Token{
			Id: int(ctx.Whisper_full_get_token_id_from_state(st, n, i)),
			Text: ctx.Whisper_full_get_token_text_from_state(st, n, i),
			P: ctx.Whisper_full_get_token_p_from_state(st, n, i),
			Start: time.Duration(data.T0()) * time.Millisecond * 10,
			End: time.Duration(data.T1()) * time.Millisecond * 10,
		}
	}
	return result
}
// ---------------------------------------------------------------------
// Backward-compatibility shims: every method below only forwards to the
// wrapped *Parameters so StatefulContext keeps satisfying the legacy
// Context interface. Prefer calling Params().<Method>() directly.
// ---------------------------------------------------------------------

// Deprecated: Use Params().Language() instead
func (context *StatefulContext) Language() string {
	return context.params.Language()
}

// Deprecated: Use Params().SetAudioCtx() instead
func (context *StatefulContext) SetAudioCtx(n uint) {
	context.params.SetAudioCtx(n)
}

// SetBeamSize implements Context.
// Deprecated: Use Params().SetBeamSize() instead
func (context *StatefulContext) SetBeamSize(v int) {
	context.params.SetBeamSize(v)
}

// SetDuration implements Context.
// Deprecated: Use Params().SetDuration() instead
func (context *StatefulContext) SetDuration(v time.Duration) {
	context.params.SetDuration(v)
}

// SetEntropyThold implements Context.
// Deprecated: Use Params().SetEntropyThold() instead
func (context *StatefulContext) SetEntropyThold(v float32) {
	context.params.SetEntropyThold(v)
}

// SetInitialPrompt implements Context.
// Deprecated: Use Params().SetInitialPrompt() instead
func (context *StatefulContext) SetInitialPrompt(v string) {
	context.params.SetInitialPrompt(v)
}

// SetMaxContext implements Context.
// Deprecated: Use Params().SetMaxContext() instead
func (context *StatefulContext) SetMaxContext(v int) {
	context.params.SetMaxContext(v)
}

// SetMaxSegmentLength implements Context.
// Deprecated: Use Params().SetMaxSegmentLength() instead
func (context *StatefulContext) SetMaxSegmentLength(v uint) {
	context.params.SetMaxSegmentLength(v)
}

// SetMaxTokensPerSegment implements Context.
// Deprecated: Use Params().SetMaxTokensPerSegment() instead
func (context *StatefulContext) SetMaxTokensPerSegment(v uint) {
	context.params.SetMaxTokensPerSegment(v)
}

// SetOffset implements Context.
// Deprecated: Use Params().SetOffset() instead
func (context *StatefulContext) SetOffset(v time.Duration) {
	context.params.SetOffset(v)
}

// SetSplitOnWord implements Context.
// Deprecated: Use Params().SetSplitOnWord() instead
func (context *StatefulContext) SetSplitOnWord(v bool) {
	context.params.SetSplitOnWord(v)
}

// SetTemperature implements Context.
// Deprecated: Use Params().SetTemperature() instead
func (context *StatefulContext) SetTemperature(v float32) {
	context.params.SetTemperature(v)
}

// SetTemperatureFallback implements Context.
// Deprecated: Use Params().SetTemperatureFallback() instead
func (context *StatefulContext) SetTemperatureFallback(v float32) {
	context.params.SetTemperatureFallback(v)
}

// SetThreads implements Context.
// Deprecated: Use Params().SetThreads() instead
func (context *StatefulContext) SetThreads(v uint) {
	context.params.SetThreads(v)
}

// SetTokenSumThreshold implements Context.
// Deprecated: Use Params().SetTokenSumThreshold() instead
func (context *StatefulContext) SetTokenSumThreshold(v float32) {
	context.params.SetTokenSumThreshold(v)
}

// SetTokenThreshold implements Context.
// Deprecated: Use Params().SetTokenThreshold() instead
func (context *StatefulContext) SetTokenThreshold(v float32) {
	context.params.SetTokenThreshold(v)
}

// SetTokenTimestamps implements Context.
// Deprecated: Use Params().SetTokenTimestamps() instead
func (context *StatefulContext) SetTokenTimestamps(v bool) {
	context.params.SetTokenTimestamps(v)
}

// SetTranslate implements Context.
// Deprecated: Use Params().SetTranslate() instead
func (context *StatefulContext) SetTranslate(v bool) {
	context.params.SetTranslate(v)
}

// VAD methods - implement Context interface

// Deprecated: Use Params().SetVAD() instead
func (context *StatefulContext) SetVAD(v bool) {
	context.params.SetVAD(v)
}

// Deprecated: Use Params().SetVADModelPath() instead
func (context *StatefulContext) SetVADModelPath(path string) {
	context.params.SetVADModelPath(path)
}

// Deprecated: Use Params().SetVADThreshold() instead
func (context *StatefulContext) SetVADThreshold(t float32) {
	context.params.SetVADThreshold(t)
}

// Deprecated: Use Params().SetVADMinSpeechMs() instead
func (context *StatefulContext) SetVADMinSpeechMs(ms int) {
	context.params.SetVADMinSpeechMs(ms)
}

// Deprecated: Use Params().SetVADMinSilenceMs() instead
func (context *StatefulContext) SetVADMinSilenceMs(ms int) {
	context.params.SetVADMinSilenceMs(ms)
}

// Deprecated: Use Params().SetVADMaxSpeechSec() instead
func (context *StatefulContext) SetVADMaxSpeechSec(s float32) {
	context.params.SetVADMaxSpeechSec(s)
}

// Deprecated: Use Params().SetVADSpeechPadMs() instead
func (context *StatefulContext) SetVADSpeechPadMs(ms int) {
	context.params.SetVADSpeechPadMs(ms)
}

// Deprecated: Use Params().SetVADSamplesOverlap() instead
func (context *StatefulContext) SetVADSamplesOverlap(sec float32) {
	context.params.SetVADSamplesOverlap(sec)
}

// Make stateful context compatible with the old deprecated interface for
// the simple migration into multi-threaded processing.
var _ Context = (*StatefulContext)(nil)

View File

@ -0,0 +1,81 @@
package whisper_test
import (
"os"
"sync"
"testing"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
assert "github.com/stretchr/testify/assert"
)
// Stateful-specific: parallel processing supported.
// Two StatefulContexts sharing one model transcribe different inputs
// concurrently; their first segments must differ.
func TestContext_Parallel_DifferentInputs_Stateful(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	data := helperLoadSample(t, SamplePath)
	assert.Greater(len(data), 10)
	// Create half-sample (second half)
	half := make([]float32, len(data)/2)
	copy(half, data[len(data)/2:])
	model, err := whisper.NewModelContext(ModelPath)
	assert.NoError(err)
	defer func() { _ = model.Close() }()
	params1 := helperNewParams(t, model, nil)
	params2 := helperNewParams(t, model, nil)
	ctx1, err := whisper.NewStatefulContext(model, params1)
	assert.NoError(err)
	defer func() { _ = ctx1.Close() }()
	ctx2, err := whisper.NewStatefulContext(model, params2)
	assert.NoError(err)
	defer func() { _ = ctx2.Close() }()
	var wg sync.WaitGroup
	var first1, first2 string
	var e1, e2 error
	wg.Add(2)
	go func() {
		defer wg.Done()
		e1 = ctx1.Process(data, nil, nil, nil)
		if e1 == nil {
			seg, err := ctx1.NextSegment()
			if err == nil {
				first1 = seg.Text
			} else {
				e1 = err
			}
		}
	}()
	go func() {
		defer wg.Done()
		e2 = ctx2.Process(half, nil, nil, nil)
		if e2 == nil {
			seg, err := ctx2.NextSegment()
			if err == nil {
				first2 = seg.Text
			} else {
				e2 = err
			}
		}
	}()
	wg.Wait()
	assert.NoError(e1)
	assert.NoError(e2)
	assert.NotEmpty(first1)
	assert.NotEmpty(first2)
	assert.NotEqual(first1, first2, "first segments should differ for different inputs")
}

View File

@ -0,0 +1,418 @@
package whisper
import (
"fmt"
"io"
"runtime"
"strings"
"time"
// Bindings
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
)
// StatelessContext reuses the model's single internal whisper state;
// only one stateless context per model may process at a time (enforced
// by a gate in Process).
type StatelessContext struct {
	n int                 // cursor for NextSegment over the last Process result
	model *ModelContext   // shared model; not owned by this context
	params *Parameters    // high-level decoding parameters
	closed bool           // set by Close; guards further use
}
// NewStatelessContext creates a new stateless context backed by the
// model's shared internal state. Both arguments are mandatory; the call
// fails fast when the model context has already been closed.
func NewStatelessContext(model *ModelContext, params *Parameters) (*StatelessContext, error) {
	switch {
	case model == nil:
		return nil, errModelRequired
	case params == nil:
		return nil, errParametersRequired
	}
	// Ensure model context is available
	if _, err := model.ctxAccessor().context(); err != nil {
		return nil, err
	}
	return &StatelessContext{model: model, params: params}, nil
}
// DetectedLanguage returns the detected language for the current context
// data, or "" when this context or the model has been closed.
func (context *StatelessContext) DetectedLanguage() string {
	if context.closed {
		return ""
	}
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return ""
	}
	return whisper.Whisper_lang_str(ctx.Whisper_full_lang_id())
}

// Close marks the context as closed. The shared model state is not freed.
func (context *StatelessContext) Close() error {
	context.closed = true
	return nil
}

// Params returns a high-level parameters wrapper
func (context *StatelessContext) Params() *Parameters {
	return context.params
}

// ResetTimings resets the model performance timing counters.
// Deprecated: Use Model.ResetTimings() instead - these are model-level performance metrics.
func (context *StatelessContext) ResetTimings() {
	context.model.ResetTimings()
}

// PrintTimings prints the model performance timings to stdout.
// Deprecated: Use Model.PrintTimings() instead - these are model-level performance metrics.
func (context *StatelessContext) PrintTimings() {
	context.model.PrintTimings()
}

// SystemInfo returns the system information
func (context *StatelessContext) SystemInfo() string {
	return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
		context.params.Threads(),
		runtime.NumCPU(),
		whisper.Whisper_print_system_info(),
	)
}
// Use mel data at offset_ms to try and auto-detect the spoken language.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages for this context.
func (context *StatelessContext) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
	if context.closed {
		return nil, ErrModelClosed
	}
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return nil, err
	}
	probs, err := ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
	if err != nil {
		return nil, err
	}
	return probs, nil
}
// Process new sample data and return any errors.
//
// Because stateless contexts share the model's single internal state,
// a per-model gate rejects overlapping calls with ErrStatelessBusy
// instead of racing inside whisper_full.
func (context *StatelessContext) Process(
	data []float32,
	callEncoderBegin EncoderBeginCallback,
	callNewSegment SegmentCallback,
	callProgress ProgressCallback,
) error {
	if context.closed {
		return ErrModelClosed
	}
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return err
	}
	// Concurrency guard: prevent concurrent stateless processing on shared model ctx
	k := modelKey(context.model)
	if !gate().Acquire(k) {
		return ErrStatelessBusy
	}
	defer gate().Release(k)
	// If the callback is defined then we force on single_segment mode.
	// NOTE(review): this mutates the shared Parameters as a side effect.
	if callNewSegment != nil {
		context.params.SetSingleSegment(true)
	}
	lowLevelParams, err := context.params.unsafeParams()
	if err != nil {
		return err
	}
	if err := ctx.Whisper_full(*lowLevelParams, data, callEncoderBegin,
		func(new int) {
			if callNewSegment != nil {
				// "new" is the number of segments added by this step;
				// replay only those to the callback.
				num_segments := ctx.Whisper_full_n_segments()
				s0 := num_segments - new
				for i := s0; i < num_segments; i++ {
					callNewSegment(toSegmentFromContext(ctx, i))
				}
			}
		}, func(progress int) {
			if callProgress != nil {
				callProgress(progress)
			}
		}); err != nil {
		return err
	}
	// Return success
	return nil
}
// NextSegment returns the next unread segment from the last Process
// call, advancing the internal cursor. It returns io.EOF when all
// segments have been consumed, and ErrModelClosed after Close.
func (context *StatelessContext) NextSegment() (Segment, error) {
	if context.closed {
		return Segment{}, ErrModelClosed
	}
	ctx, err := context.model.ctxAccessor().context()
	if err != nil {
		return Segment{}, err
	}
	total := ctx.Whisper_full_n_segments()
	if context.n >= total {
		return Segment{}, io.EOF
	}
	seg := toSegmentFromContext(ctx, context.n)
	context.n++
	return seg, nil
}
// IsMultilingual reports whether the underlying model supports multiple languages.
func (context *StatelessContext) IsMultilingual() bool {
	return context.model.IsMultilingual()
}

// Token helpers. Errors from the token identifier are deliberately
// discarded in these deprecated shims; they report false on failure.

// Deprecated: Use Model.IsText() instead - token checking is model-specific.
func (context *StatelessContext) IsText(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsText(t)
	return result
}

// Deprecated: Use Model.IsBEG() instead - token checking is model-specific.
func (context *StatelessContext) IsBEG(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsBEG(t)
	return result
}

// Deprecated: Use Model.IsSOT() instead - token checking is model-specific.
func (context *StatelessContext) IsSOT(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsSOT(t)
	return result
}

// Deprecated: Use Model.IsEOT() instead - token checking is model-specific.
func (context *StatelessContext) IsEOT(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsEOT(t)
	return result
}

// Deprecated: Use Model.IsPREV() instead - token checking is model-specific.
func (context *StatelessContext) IsPREV(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsPREV(t)
	return result
}

// Deprecated: Use Model.IsSOLM() instead - token checking is model-specific.
func (context *StatelessContext) IsSOLM(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsSOLM(t)
	return result
}

// Deprecated: Use Model.IsNOT() instead - token checking is model-specific.
func (context *StatelessContext) IsNOT(t Token) bool {
	result, _ := context.model.tokenIdentifier().IsNOT(t)
	return result
}

// SetLanguage sets the transcription language; it fails when this
// context or the model is closed, or the model is not multilingual.
func (context *StatelessContext) SetLanguage(lang string) error {
	if context.closed || context.model.ctxAccessor().isClosed() {
		return ErrModelClosed
	}
	if !context.model.IsMultilingual() {
		return ErrModelNotMultilingual
	}
	return context.params.SetLanguage(lang)
}

// Deprecated: Use Model.IsLANG() instead - token checking is model-specific.
func (context *StatelessContext) IsLANG(t Token, lang string) bool {
	result, _ := context.model.tokenIdentifier().IsLANG(t, lang)
	return result
}
// Context-backed helper functions

// toSegmentFromContext converts segment n of the shared context into the
// high-level Segment type. The low-level t0/t1 timestamps are multiplied
// by 10*time.Millisecond, i.e. treated as 10 ms ticks.
func toSegmentFromContext(ctx *whisper.Context, n int) Segment {
	return Segment{
		Num: n,
		Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
		Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
		End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
		Tokens: toTokensFromContext(ctx, n),
		SpeakerTurnNext: false, // speaker turn available only with state-backed accessors
	}
}

// toTokensFromContext collects every token of segment n with its text,
// probability and 10 ms-tick timestamps.
func toTokensFromContext(ctx *whisper.Context, n int) []Token {
	result := make([]Token, ctx.Whisper_full_n_tokens(n))
	for i := 0; i < len(result); i++ {
		data := ctx.Whisper_full_get_token_data(n, i)
		result[i] = Token{
			Id: int(ctx.Whisper_full_get_token_id(n, i)),
			Text: ctx.Whisper_full_get_token_text(n, i),
			P: ctx.Whisper_full_get_token_p(n, i),
			Start: time.Duration(data.T0()) * time.Millisecond * 10,
			End: time.Duration(data.T1()) * time.Millisecond * 10,
		}
	}
	return result
}
// ---------------------------------------------------------------------
// Backward-compatibility shims: every method below only forwards to the
// wrapped *Parameters so StatelessContext keeps satisfying the legacy
// Context interface. Prefer calling Params().<Method>() directly.
// ---------------------------------------------------------------------

// Deprecated: Use Params().Language() instead
func (context *StatelessContext) Language() string {
	return context.params.Language()
}

// Deprecated: Use Params().SetAudioCtx() instead
func (context *StatelessContext) SetAudioCtx(n uint) {
	context.params.SetAudioCtx(n)
}

// SetBeamSize implements Context.
// Deprecated: Use Params().SetBeamSize() instead
func (context *StatelessContext) SetBeamSize(v int) {
	context.params.SetBeamSize(v)
}

// SetDuration implements Context.
// Deprecated: Use Params().SetDuration() instead
func (context *StatelessContext) SetDuration(v time.Duration) {
	context.params.SetDuration(v)
}

// SetEntropyThold implements Context.
// Deprecated: Use Params().SetEntropyThold() instead
func (context *StatelessContext) SetEntropyThold(v float32) {
	context.params.SetEntropyThold(v)
}

// SetInitialPrompt implements Context.
// Deprecated: Use Params().SetInitialPrompt() instead
func (context *StatelessContext) SetInitialPrompt(v string) {
	context.params.SetInitialPrompt(v)
}

// SetMaxContext implements Context.
// Deprecated: Use Params().SetMaxContext() instead
func (context *StatelessContext) SetMaxContext(v int) {
	context.params.SetMaxContext(v)
}

// SetMaxSegmentLength implements Context.
// Deprecated: Use Params().SetMaxSegmentLength() instead
func (context *StatelessContext) SetMaxSegmentLength(v uint) {
	context.params.SetMaxSegmentLength(v)
}

// SetMaxTokensPerSegment implements Context.
// Deprecated: Use Params().SetMaxTokensPerSegment() instead
func (context *StatelessContext) SetMaxTokensPerSegment(v uint) {
	context.params.SetMaxTokensPerSegment(v)
}

// SetOffset implements Context.
// Deprecated: Use Params().SetOffset() instead
func (context *StatelessContext) SetOffset(v time.Duration) {
	context.params.SetOffset(v)
}

// SetSplitOnWord implements Context.
// Deprecated: Use Params().SetSplitOnWord() instead
func (context *StatelessContext) SetSplitOnWord(v bool) {
	context.params.SetSplitOnWord(v)
}

// SetTemperature implements Context.
// Deprecated: Use Params().SetTemperature() instead
func (context *StatelessContext) SetTemperature(v float32) {
	context.params.SetTemperature(v)
}

// SetTemperatureFallback implements Context.
// Deprecated: Use Params().SetTemperatureFallback() instead
func (context *StatelessContext) SetTemperatureFallback(v float32) {
	context.params.SetTemperatureFallback(v)
}

// SetThreads implements Context.
// Deprecated: Use Params().SetThreads() instead
func (context *StatelessContext) SetThreads(v uint) {
	context.params.SetThreads(v)
}

// SetTokenSumThreshold implements Context.
// Deprecated: Use Params().SetTokenSumThreshold() instead
func (context *StatelessContext) SetTokenSumThreshold(v float32) {
	context.params.SetTokenSumThreshold(v)
}

// SetTokenThreshold implements Context.
// Deprecated: Use Params().SetTokenThreshold() instead
func (context *StatelessContext) SetTokenThreshold(v float32) {
	context.params.SetTokenThreshold(v)
}

// SetTokenTimestamps implements Context.
// Deprecated: Use Params().SetTokenTimestamps() instead
func (context *StatelessContext) SetTokenTimestamps(v bool) {
	context.params.SetTokenTimestamps(v)
}

// SetTranslate implements Context.
// Deprecated: Use Params().SetTranslate() instead
func (context *StatelessContext) SetTranslate(v bool) {
	context.params.SetTranslate(v)
}

// VAD methods - implement Context interface

// Deprecated: Use Params().SetVAD() instead
func (context *StatelessContext) SetVAD(v bool) {
	context.params.SetVAD(v)
}

// Deprecated: Use Params().SetVADModelPath() instead
func (context *StatelessContext) SetVADModelPath(path string) {
	context.params.SetVADModelPath(path)
}

// Deprecated: Use Params().SetVADThreshold() instead
func (context *StatelessContext) SetVADThreshold(t float32) {
	context.params.SetVADThreshold(t)
}

// Deprecated: Use Params().SetVADMinSpeechMs() instead
func (context *StatelessContext) SetVADMinSpeechMs(ms int) {
	context.params.SetVADMinSpeechMs(ms)
}

// Deprecated: Use Params().SetVADMinSilenceMs() instead
func (context *StatelessContext) SetVADMinSilenceMs(ms int) {
	context.params.SetVADMinSilenceMs(ms)
}

// Deprecated: Use Params().SetVADMaxSpeechSec() instead
func (context *StatelessContext) SetVADMaxSpeechSec(s float32) {
	context.params.SetVADMaxSpeechSec(s)
}

// Deprecated: Use Params().SetVADSpeechPadMs() instead
func (context *StatelessContext) SetVADSpeechPadMs(ms int) {
	context.params.SetVADSpeechPadMs(ms)
}

// Deprecated: Use Params().SetVADSamplesOverlap() instead
func (context *StatelessContext) SetVADSamplesOverlap(sec float32) {
	context.params.SetVADSamplesOverlap(sec)
}

var _ Context = (*StatelessContext)(nil)

View File

@ -0,0 +1,52 @@
package whisper_test
import (
"sync"
"testing"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
assert "github.com/stretchr/testify/assert"
)
// Ensure stateless contexts cannot process in parallel without isolation.
func TestStatelessContext_NotParallelSafe(t *testing.T) {
	data := helperLoadSample(t, SamplePath)
	model, closeModel := helperNewModelContext(t)
	defer closeModel()
	params := helperNewParams(t, model, nil)
	// Create two stateless contexts sharing the same underlying model context
	ctx1, err := whisper.NewStatelessContext(model, params)
	assert.NoError(t, err)
	defer func() { _ = ctx1.Close() }()
	ctx2, err := whisper.NewStatelessContext(model, params)
	assert.NoError(t, err)
	defer func() { _ = ctx2.Close() }()
	// Run both in parallel - the per-model gate should reject one call
	// with ErrStatelessBusy instead of letting whisper_full race.
	var wg sync.WaitGroup
	wg.Add(2)
	var err1, err2 error
	go func() {
		defer wg.Done()
		err1 = ctx1.Process(data, nil, nil, nil)
	}()
	go func() {
		defer wg.Done()
		err2 = ctx2.Process(data, nil, nil, nil)
	}()
	wg.Wait()
	// At least one should return ErrStatelessBusy
	if err1 != whisper.ErrStatelessBusy && err2 != whisper.ErrStatelessBusy {
		t.Fatalf("expected ErrStatelessBusy when processing in parallel with StatelessContext, got err1=%v err2=%v", err1, err2)
	}
}

View File

@ -0,0 +1,129 @@
package whisper_test
import (
"os"
"testing"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav"
)
// helperLoadSample decodes a mono WAV file into float32 PCM samples,
// failing the test on any open/decode error or non-mono input.
func helperLoadSample(tb testing.TB, path string) []float32 {
	tb.Helper()
	// Delegate to the metadata-aware loader to avoid duplicating the
	// open/decode/validation logic.
	data, _, _ := helperLoadSampleWithMeta(tb, path)
	return data
}

// helperLoadSampleWithMeta loads wav and returns samples with sample rate and channels.
// It fails the test on any open/decode error or non-mono input.
func helperLoadSampleWithMeta(tb testing.TB, path string) ([]float32, int, int) {
	tb.Helper()
	fh, err := os.Open(path)
	if err != nil {
		tb.Fatalf("open sample: %v", err)
	}
	defer func() { _ = fh.Close() }()
	dec := wav.NewDecoder(fh)
	buf, err := dec.FullPCMBuffer()
	if err != nil {
		tb.Fatalf("decode wav: %v", err)
	}
	if dec.NumChans != 1 {
		tb.Fatalf("expected mono wav, got channels=%d", dec.NumChans)
	}
	return buf.AsFloat32Buffer().Data, int(dec.SampleRate), int(dec.NumChans)
}
// helperNewModel loads the test model via the legacy Model API, skipping
// the test when the model file is absent. The second return value closes it.
func helperNewModel(t *testing.T) (whisper.Model, func()) {
	t.Helper()
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	model, err := whisper.New(ModelPath)
	if err != nil {
		t.Fatalf("load model: %v", err)
	}
	return model, func() { _ = model.Close() }
}

// helperNewModelContext loads the test model as a ModelContext, skipping
// the test when the model file is absent. The second return value closes it.
func helperNewModelContext(t *testing.T) (*whisper.ModelContext, func()) {
	t.Helper()
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	model, err := whisper.NewModelContext(ModelPath)
	if err != nil {
		t.Fatalf("load model ctx: %v", err)
	}
	return model, func() { _ = model.Close() }
}

// helperNewParams builds greedy-sampling Parameters for model, applying
// the optional configure callback; it fails the test on error.
func helperNewParams(t *testing.T, model *whisper.ModelContext, configure whisper.ParamsConfigure) *whisper.Parameters {
	t.Helper()
	params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, configure)
	if err != nil {
		t.Fatalf("new params: %v", err)
	}
	return params
}

// helperProcessOnce runs a single Process call without callbacks,
// failing the test on error.
func helperProcessOnce(t *testing.T, ctx whisper.Context, data []float32) {
	t.Helper()
	if err := ctx.Process(data, nil, nil, nil); err != nil {
		t.Fatalf("process: %v", err)
	}
}

// helperFirstSegmentText returns the text of the next available segment,
// failing the test on error.
func helperFirstSegmentText(t *testing.T, ctx whisper.Context) string {
	t.Helper()
	seg, err := ctx.NextSegment()
	if err != nil {
		t.Fatalf("next segment: %v", err)
	}
	return seg.Text
}

// helperNewStatelessContext creates a fresh stateless context and returns a cleanup func
func helperNewStatelessContext(t *testing.T) (whisper.Context, func()) {
	t.Helper()
	model, closeModel := helperNewModelContext(t)
	params := helperNewParams(t, model, nil)
	ctx, err := whisper.NewStatelessContext(model, params)
	if err != nil {
		t.Fatalf("new stateless context: %v", err)
	}
	cleanup := func() {
		_ = ctx.Close()
		closeModel()
	}
	return ctx, cleanup
}

// helperNewStatefulContext creates a fresh stateful context and returns a cleanup func
func helperNewStatefulContext(t *testing.T) (whisper.Context, func()) {
	t.Helper()
	model, closeModel := helperNewModelContext(t)
	params := helperNewParams(t, model, nil)
	ctx, err := whisper.NewStatefulContext(model, params)
	if err != nil {
		t.Fatalf("new stateful context: %v", err)
	}
	cleanup := func() {
		_ = ctx.Close()
		closeModel()
	}
	return ctx, cleanup
}

View File

@ -0,0 +1,115 @@
package whisper
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
// tokenIdentifier classifies tokens against the model vocabulary exposed
// through a shared ctxAccessor; every check fails with ErrModelClosed
// once the accessor is closed.
type tokenIdentifier struct {
	ctx *ctxAccessor
}

// newTokenIdentifier wraps the given context accessor.
func newTokenIdentifier(whisperContext *ctxAccessor) *tokenIdentifier {
	return &tokenIdentifier{
		ctx: whisperContext,
	}
}
// Token type checking methods (model-specific vocabulary).
// Each method compares the token id against the corresponding special
// token of the underlying model; all return an error once the model
// context is closed.

// IsBEG reports whether t is the "begin" token.
func (ti *tokenIdentifier) IsBEG(t Token) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	return whisper.Token(t.Id) == ctx.Whisper_token_beg(), nil
}

// IsEOT reports whether t is the end-of-transcript token.
func (ti *tokenIdentifier) IsEOT(t Token) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	return whisper.Token(t.Id) == ctx.Whisper_token_eot(), nil
}

// IsSOT reports whether t is the start-of-transcript token.
func (ti *tokenIdentifier) IsSOT(t Token) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	return whisper.Token(t.Id) == ctx.Whisper_token_sot(), nil
}

// IsPREV reports whether t is the "previous" token.
func (ti *tokenIdentifier) IsPREV(t Token) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	return whisper.Token(t.Id) == ctx.Whisper_token_prev(), nil
}

// IsSOLM reports whether t is the SOLM token.
func (ti *tokenIdentifier) IsSOLM(t Token) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	return whisper.Token(t.Id) == ctx.Whisper_token_solm(), nil
}

// IsNOT reports whether t is the no-timestamps token.
func (ti *tokenIdentifier) IsNOT(t Token) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	return whisper.Token(t.Id) == ctx.Whisper_token_not(), nil
}

// IsLANG reports whether t is the language token for lang; an unknown
// language code yields (false, nil).
func (ti *tokenIdentifier) IsLANG(t Token, lang string) (bool, error) {
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	if id := ctx.Whisper_lang_id(lang); id >= 0 {
		return whisper.Token(t.Id) == ctx.Whisper_token_lang(id), nil
	}
	return false, nil
}
// IsText reports whether t is a regular text token, i.e. none of the
// model's special markers (BEG, SOT, anything at or beyond EOT, PREV,
// SOLM, NOT). A closed model context surfaces as an error.
func (ti *tokenIdentifier) IsText(t Token) (bool, error) {
	// Special-token checks; their errors are ignored here because a
	// closed context is reported by the context() lookup below.
	for _, isSpecial := range []func(Token) (bool, error){ti.IsBEG, ti.IsSOT} {
		if hit, _ := isSpecial(t); hit {
			return false, nil
		}
	}
	ctx, err := ti.ctx.context()
	if err != nil {
		return false, err
	}
	// Token ids at or beyond EOT are all control tokens.
	if whisper.Token(t.Id) >= ctx.Whisper_token_eot() {
		return false, nil
	}
	for _, isSpecial := range []func(Token) (bool, error){ti.IsPREV, ti.IsSOLM, ti.IsNOT} {
		if hit, _ := isSpecial(t); hit {
			return false, nil
		}
	}
	return true, nil
}

View File

@ -1,6 +1,16 @@
package whisper_test
import (
"os"
"testing"
)
const (
ModelPath = "../../models/ggml-small.en.bin"
ModelPath = "../../models/ggml-tiny.en.bin"
SamplePath = "../../samples/jfk.wav"
)
func TestMain(m *testing.M) {
// whisper.DisableLogs()
os.Exit(m.Run())
}

View File

@ -0,0 +1,36 @@
package whisper
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
// ctxAccessor wraps a raw whisper.Context pointer and gates access to it
// after the model has been closed. A nil inner pointer means "closed".
type ctxAccessor struct {
	ctx *whisper.Context // nil once close() has run (or if never set)
}
// newCtxAccessor returns an accessor owning the given raw context.
// Passing nil produces an accessor that is already closed.
func newCtxAccessor(ctx *whisper.Context) *ctxAccessor {
	return &ctxAccessor{ctx: ctx}
}
// close frees the underlying whisper context, if any, and marks the
// accessor closed. It is idempotent and always returns nil.
func (ctx *ctxAccessor) close() error {
	if ctx.ctx != nil {
		ctx.ctx.Whisper_free()
		ctx.ctx = nil
	}
	return nil
}
// isClosed reports whether close() has been called (or the accessor was
// constructed around a nil context).
func (ctx *ctxAccessor) isClosed() bool {
	return ctx.ctx == nil
}
// context returns the raw whisper context, or ErrModelClosed once the
// accessor has been closed.
func (ctx *ctxAccessor) context() (*whisper.Context, error) {
	if ctx.ctx == nil {
		return nil, ErrModelClosed
	}
	return ctx.ctx, nil
}

View File

@ -0,0 +1,85 @@
package whisper
import (
"os"
"testing"
w "github.com/ggerganov/whisper.cpp/bindings/go"
assert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testModelPathCtx = "../../models/ggml-tiny.en.bin"
// TestWhisperCtx_NilWrapper verifies that an accessor built around a nil
// context is immediately closed, surfaces ErrModelClosed, and that close()
// is idempotent. Needs no model file.
func TestWhisperCtx_NilWrapper(t *testing.T) {
	wctx := newCtxAccessor(nil)
	assert.True(t, wctx.isClosed())
	raw, err := wctx.context()
	assert.Nil(t, raw)
	require.ErrorIs(t, err, ErrModelClosed)
	require.NoError(t, wctx.close())
	// idempotent
	require.NoError(t, wctx.close())
}
// TestWhisperCtx_Lifecycle exercises the full open -> use -> close -> closed
// lifecycle of ctxAccessor against a real model context. Skipped when the
// model file is not present.
func TestWhisperCtx_Lifecycle(t *testing.T) {
	if _, err := os.Stat(testModelPathCtx); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", testModelPathCtx)
	}
	raw := w.Whisper_init(testModelPathCtx)
	require.NotNil(t, raw)
	wctx := newCtxAccessor(raw)
	assert.False(t, wctx.isClosed())
	got, err := wctx.context()
	require.NoError(t, err)
	require.NotNil(t, got)
	// close frees underlying ctx and marks closed
	require.NoError(t, wctx.close())
	assert.True(t, wctx.isClosed())
	got, err = wctx.context()
	assert.Nil(t, got)
	require.ErrorIs(t, err, ErrModelClosed)
	// idempotent
	require.NoError(t, wctx.close())
	// no further free; raw already freed by wctx.Close()
}
// TestWhisperCtx_FromModelLifecycle verifies that closing a ModelContext
// also closes the ctxAccessor it exposes, and that the accessor then
// reports ErrModelClosed. Skipped when the model file is not present.
func TestWhisperCtx_FromModelLifecycle(t *testing.T) {
	if _, err := os.Stat(testModelPathCtx); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", testModelPathCtx)
	}
	modelNew, err := New(testModelPathCtx)
	require.NoError(t, err)
	require.NotNil(t, modelNew)
	model := modelNew.(*ModelContext)
	wc := model.ctxAccessor()
	require.NotNil(t, wc)
	// Should be usable before model.Close
	raw, err := wc.context()
	require.NoError(t, err)
	require.NotNil(t, raw)
	// Close model should close underlying context
	require.NoError(t, model.Close())
	assert.True(t, wc.isClosed())
	raw, err = wc.context()
	assert.Nil(t, raw)
	require.ErrorIs(t, err, ErrModelClosed)
	// Idempotent close on wrapper
	require.NoError(t, wc.close())
}

View File

@ -0,0 +1,32 @@
package whisper
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
// whisperState wraps a raw whisper.State pointer and gates access to it
// after the state has been freed. A nil inner pointer means "closed".
type whisperState struct {
	state *whisper.State // nil once close() has run (or if never set)
}
// newWhisperState returns a wrapper owning the given raw state.
// Passing nil produces a wrapper that is already closed.
func newWhisperState(state *whisper.State) *whisperState {
	return &whisperState{state: state}
}
// close frees the underlying whisper state, if any, and marks the wrapper
// closed. It is idempotent and always returns nil.
func (s *whisperState) close() error {
	if s.state != nil {
		s.state.Whisper_free_state()
		s.state = nil
	}
	return nil
}
// unsafeState returns the raw whisper state, or ErrModelClosed once the
// wrapper has been closed. "unsafe" because the caller must not retain the
// pointer beyond the wrapper's lifetime.
func (s *whisperState) unsafeState() (*whisper.State, error) {
	if st := s.state; st != nil {
		return st, nil
	}
	return nil, ErrModelClosed
}

View File

@ -0,0 +1,53 @@
package whisper
import (
"os"
"testing"
w "github.com/ggerganov/whisper.cpp/bindings/go"
assert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testModelPathState = "../../models/ggml-tiny.en.bin"
// TestWhisperState_NilWrapper verifies that a wrapper around a nil state
// surfaces ErrModelClosed and that close() is idempotent. Needs no model.
func TestWhisperState_NilWrapper(t *testing.T) {
	ws := newWhisperState(nil)
	state, err := ws.unsafeState()
	assert.Nil(t, state)
	require.ErrorIs(t, err, ErrModelClosed)
	require.NoError(t, ws.close())
	// idempotent
	require.NoError(t, ws.close())
}
// TestWhisperState_Lifecycle exercises open -> use -> close -> closed on a
// real whisper state allocated from a real context. Skipped when the model
// file is not present.
func TestWhisperState_Lifecycle(t *testing.T) {
	if _, err := os.Stat(testModelPathState); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", testModelPathState)
	}
	ctx := w.Whisper_init(testModelPathState)
	require.NotNil(t, ctx)
	defer ctx.Whisper_free()
	state := ctx.Whisper_init_state()
	require.NotNil(t, state)
	ws := newWhisperState(state)
	got, err := ws.unsafeState()
	require.NoError(t, err)
	require.NotNil(t, got)
	// close frees underlying state and marks closed
	require.NoError(t, ws.close())
	got, err = ws.unsafeState()
	assert.Nil(t, got)
	require.ErrorIs(t, err, ErrModelClosed)
	// idempotent
	require.NoError(t, ws.close())
}

View File

@ -2,6 +2,7 @@ package whisper
import (
"errors"
"sync"
"unsafe"
)
@ -14,6 +15,7 @@ import (
#cgo darwin LDFLAGS: -lggml-metal -lggml-blas
#cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics
#include <whisper.h>
#include <ggml.h>
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
@ -59,6 +61,22 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
params.progress_callback_user_data = (void*)(ctx);
return params;
}
// Disable all C-side logging (whisper.cpp and ggml)
static void go_cb_log_disable(enum ggml_log_level level, const char * text, void * user_data) {
(void) level; (void) text; (void) user_data;
}
static void whisper_log_disable_all(void) {
ggml_log_set(go_cb_log_disable, NULL);
whisper_log_set(go_cb_log_disable, NULL);
}
// Enable default logging (stdout) for whisper.cpp and ggml
static void whisper_log_enable_default(void) {
ggml_log_set(NULL, NULL);
whisper_log_set(NULL, NULL);
}
*/
import "C"
@ -67,10 +85,13 @@ import "C"
type (
Context C.struct_whisper_context
State C.struct_whisper_state
Token C.whisper_token
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
Timings C.struct_whisper_timings
ContextParams C.struct_whisper_context_params
)
///////////////////////////////////////////////////////////////////////////////
@ -96,6 +117,12 @@ var (
ErrInvalidLanguage = errors.New("invalid language")
)
// DisableLogs disables all logging coming from the C libraries (whisper.cpp and ggml).
// Call once early in program startup if you want to silence device/backend prints.
func DisableLogs() {
C.whisper_log_disable_all()
}
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
@ -111,11 +138,54 @@ func Whisper_init(path string) *Context {
}
}
// Whisper_context_default_params returns default model context params
func Whisper_context_default_params() ContextParams {
return ContextParams(C.whisper_context_default_params())
}
// SetUseGPU enables or disables GPU acceleration on the model context (if available)
func (p *ContextParams) SetUseGPU(v bool) {
	// cgo maps C's _Bool to Go's bool, so a direct conversion suffices.
	p.use_gpu = C.bool(v)
}
// SetGPUDevice selects the GPU device index for the model context (CUDA)
// The index is passed through unchecked; validation happens in whisper.cpp.
func (p *ContextParams) SetGPUDevice(n int) {
	p.gpu_device = C.int(n)
}
// Whisper_init_with_params allocates and initializes a model using custom context params.
// Returns nil if the model could not be loaded.
func Whisper_init_with_params(path string, params ContextParams) *Context {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath)) // C.CString allocates; must be freed
	ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params))
	if ctx == nil {
		return nil
	}
	return (*Context)(ctx)
}
// Frees all memory allocated by the model.
// The context must not be used after this call.
func (ctx *Context) Whisper_free() {
	C.whisper_free((*C.struct_whisper_context)(ctx))
}
// Allocates a new state associated with the context. Returns nil on failure.
// Each state holds independent decode results, enabling state-per-goroutine use.
func (ctx *Context) Whisper_init_state() *State {
	s := C.whisper_init_state((*C.struct_whisper_context)(ctx))
	if s == nil {
		return nil
	}
	return (*State)(s)
}
// Frees all memory allocated by the state.
// The state must not be used after this call.
func (s *State) Whisper_free_state() {
	C.whisper_free_state((*C.struct_whisper_state)(s))
}
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context.
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
@ -126,6 +196,15 @@ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
}
}
// Convert RAW PCM audio to log mel spectrogram into the provided state.
// Returns ErrConversionFailed if the underlying C call reports an error.
func (ctx *Context) Whisper_pcm_to_mel_with_state(state *State, data []float32, threads int) error {
	// NOTE(review): &data[0] panics on an empty slice — confirm callers never pass empty audio.
	if C.whisper_pcm_to_mel_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
		return nil
	} else {
		return ErrConversionFailed
	}
}
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
@ -137,6 +216,15 @@ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
}
}
// Set a custom log mel spectrogram into the provided state.
func (ctx *Context) Whisper_set_mel_with_state(state *State, data []float32, n_mel int) error {
if C.whisper_set_mel_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
return nil
} else {
return ErrConversionFailed
}
}
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
@ -148,6 +236,15 @@ func (ctx *Context) Whisper_encode(offset, threads int) error {
}
}
// Run the Whisper encoder using the provided state.
// A mel spectrogram must already have been set into the state
// (see Whisper_pcm_to_mel_with_state / Whisper_set_mel_with_state).
// offset selects the first frame of the spectrogram to encode.
func (ctx *Context) Whisper_encode_with_state(state *State, offset, threads int) error {
	if C.whisper_encode_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset), C.int(threads)) == 0 {
		return nil
	} else {
		return ErrConversionFailed
	}
}
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
@ -160,6 +257,15 @@ func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
}
}
// Run the Whisper decoder using the provided state.
// tokens is the decoder prompt/context; past is the number of past tokens.
func (ctx *Context) Whisper_decode_with_state(state *State, tokens []Token, past, threads int) error {
	// NOTE(review): &tokens[0] panics on an empty slice — confirm callers always pass at least one token.
	if C.whisper_decode_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
		return nil
	} else {
		return ErrConversionFailed
	}
}
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
@ -181,6 +287,10 @@ func (ctx *Context) Whisper_lang_id(lang string) int {
return int(C.whisper_lang_id(C.CString(lang)))
}
// Whisper_lang_id_str returns the id of the specified language, or a
// negative value if it is not recognized by whisper.cpp.
//
// Fix: the previous implementation passed C.CString(lang) directly into the
// C call and never freed it, leaking one C allocation per invocation.
func Whisper_lang_id_str(lang string) int {
	cLang := C.CString(lang)
	defer C.free(unsafe.Pointer(cLang))
	return int(C.whisper_lang_id(cLang))
}
// Largest language id (i.e. number of available languages - 1)
func Whisper_lang_max_id() int {
return int(C.whisper_lang_max_id())
@ -205,6 +315,16 @@ func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float3
}
}
// Use mel data at offset_ms to auto-detect language using the provided state.
// On success the returned slice holds one probability per language,
// Whisper_lang_max_id()+1 entries in total.
func (ctx *Context) Whisper_lang_auto_detect_with_state(state *State, offset_ms, n_threads int) ([]float32, error) {
	probs := make([]float32, Whisper_lang_max_id()+1)
	n := int(C.whisper_lang_auto_detect_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0])))
	if n < 0 {
		return nil, ErrAutoDetectFailed
	}
	return probs, nil
}
func (ctx *Context) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
}
@ -290,6 +410,32 @@ func (ctx *Context) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
}
// TimingsGo is a Go-friendly copy of whisper_timings.
// All values are milliseconds as reported by whisper.cpp.
type TimingsGo struct {
	SampleMS float32 // token sampling time
	EncodeMS float32 // encoder time
	DecodeMS float32 // decoder time
	BatchdMS float32 // batched decode time
	PromptMS float32 // prompt processing time
}
// Whisper_get_timings_go retrieves timing counters and converts them to TimingsGo.
// The second return value is false when whisper.cpp returned no timings.
func (ctx *Context) Whisper_get_timings_go() (TimingsGo, bool) {
	t := C.whisper_get_timings((*C.struct_whisper_context)(ctx))
	if t == nil {
		return TimingsGo{}, false
	}
	// The C struct is 5 consecutive floats; reinterpret and copy
	// NOTE(review): this assumes whisper_timings stays exactly five packed
	// floats in this order — re-verify against whisper.h on upgrades.
	arr := (*[5]C.float)(unsafe.Pointer(t))
	return TimingsGo{
		SampleMS: float32(arr[0]),
		EncodeMS: float32(arr[1]),
		DecodeMS: float32(arr[2]),
		BatchdMS: float32(arr[3]),
		PromptMS: float32(arr[4]),
	}, true
}
// Print system information
func Whisper_print_system_info() string {
return C.GoString(C.whisper_print_system_info())
@ -323,6 +469,28 @@ func (ctx *Context) Whisper_full(
}
}
// Run the entire model using the provided state: PCM -> mel -> encoder -> decoder -> text.
// Results are stored in the state and read back via the *_from_state accessors.
// Any of the three callbacks may be nil.
func (ctx *Context) Whisper_full_with_state(
	state *State,
	params Params,
	samples []float32,
	encoderBeginCallback func() bool,
	newSegmentCallback func(int),
	progressCallback func(int),
) error {
	// NOTE(review): callbacks are registered keyed by ctx, not by state, so
	// concurrent calls on the same ctx with different states would overwrite
	// each other's callbacks — confirm callers serialize access per context.
	registerEncoderBeginCallback(ctx, encoderBeginCallback)
	registerNewSegmentCallback(ctx, newSegmentCallback)
	registerProgressCallback(ctx, progressCallback)
	defer registerEncoderBeginCallback(ctx, nil)
	defer registerNewSegmentCallback(ctx, nil)
	defer registerProgressCallback(ctx, nil)
	// NOTE(review): &samples[0] panics on an empty slice — confirm callers never pass empty audio.
	if C.whisper_full_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
		return nil
	} else {
		return ErrConversionFailed
	}
}
// Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
@ -357,102 +525,157 @@ func (ctx *Context) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
}
func (ctx *Context) Whisper_full_n_segments_from_state(state *State) int {
return int(C.whisper_full_n_segments_from_state((*C.struct_whisper_state)(state)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
}
func (ctx *Context) Whisper_full_get_segment_t0_from_state(state *State, segment int) int64 {
return int64(C.whisper_full_get_segment_t0_from_state((*C.struct_whisper_state)(state), C.int(segment)))
}
// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
}
func (ctx *Context) Whisper_full_get_segment_t1_from_state(state *State, segment int) int64 {
return int64(C.whisper_full_get_segment_t1_from_state((*C.struct_whisper_state)(state), C.int(segment)))
}
// Get the text of the specified segment.
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
}
func (ctx *Context) Whisper_full_get_segment_text_from_state(state *State, segment int) string {
return C.GoString(C.whisper_full_get_segment_text_from_state((*C.struct_whisper_state)(state), C.int(segment)))
}
// Get number of tokens in the specified segment.
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
}
func (ctx *Context) Whisper_full_n_tokens_from_state(state *State, segment int) int {
return int(C.whisper_full_n_tokens_from_state((*C.struct_whisper_state)(state), C.int(segment)))
}
// Get the token text of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
func (ctx *Context) Whisper_full_get_token_text_from_state(state *State, segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text_from_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}
// Get the token of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
func (ctx *Context) Whisper_full_get_token_id_from_state(state *State, segment int, token int) Token {
return Token(C.whisper_full_get_token_id_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}
// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
func (ctx *Context) Whisper_full_get_token_data_from_state(state *State, segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}
// Get the probability of the specified token in the specified segment.
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
}
func (ctx *Context) Whisper_full_get_token_p_from_state(state *State, segment int, token int) float32 {
return float32(C.whisper_full_get_token_p_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}
func (ctx *Context) Whisper_full_lang_id_from_state(state *State) int {
return int(C.whisper_full_lang_id_from_state((*C.struct_whisper_state)(state)))
}
func (ctx *Context) Whisper_n_len_from_state(state *State) int {
return int(C.whisper_n_len_from_state((*C.struct_whisper_state)(state)))
}
// Whisper_get_logits_from_state returns the most recent decoder logits held
// in the state as a Go slice of length Whisper_n_vocab().
// The slice aliases C memory owned by the state: it is only valid until the
// state is freed or the next decode call, and must not be retained.
func (ctx *Context) Whisper_get_logits_from_state(state *State) []float32 {
	// (*[1 << 30]float32) is the classic cgo idiom to view a C float* as a
	// Go array; the slice bound below limits access to the vocab size.
	return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_logits_from_state((*C.struct_whisper_state)(state))))[:ctx.Whisper_n_vocab()]
}
// Get whether the next segment is predicted as a speaker turn (tinydiarize)
func (ctx *Context) Whisper_full_get_segment_speaker_turn_next_from_state(state *State, segment int) bool {
return bool(C.whisper_full_get_segment_speaker_turn_next_from_state((*C.struct_whisper_state)(state), C.int(segment)))
}
///////////////////////////////////////////////////////////////////////////////
// CALLBACKS
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbProgress = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
cbNewSegment sync.Map // map[unsafe.Pointer]func(int)
cbProgress sync.Map // map[unsafe.Pointer]func(int)
cbEncoderBegin sync.Map // map[unsafe.Pointer]func() bool
)
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
k := unsafe.Pointer(ctx)
if fn == nil {
delete(cbNewSegment, unsafe.Pointer(ctx))
cbNewSegment.Delete(k)
} else {
cbNewSegment[unsafe.Pointer(ctx)] = fn
cbNewSegment.Store(k, fn)
}
}
func registerProgressCallback(ctx *Context, fn func(int)) {
k := unsafe.Pointer(ctx)
if fn == nil {
delete(cbProgress, unsafe.Pointer(ctx))
cbProgress.Delete(k)
} else {
cbProgress[unsafe.Pointer(ctx)] = fn
cbProgress.Store(k, fn)
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
k := unsafe.Pointer(ctx)
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
cbEncoderBegin.Delete(k)
} else {
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
cbEncoderBegin.Store(k, fn)
}
}
//export callNewSegment
func callNewSegment(user_data unsafe.Pointer, new C.int) {
if fn, ok := cbNewSegment[user_data]; ok {
fn(int(new))
if v, ok := cbNewSegment.Load(user_data); ok {
v.(func(int))(int(new))
}
}
//export callProgress
func callProgress(user_data unsafe.Pointer, progress C.int) {
if fn, ok := cbProgress[user_data]; ok {
fn(int(progress))
if v, ok := cbProgress.Load(user_data); ok {
v.(func(int))(int(progress))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
if fn() {
if v, ok := cbEncoderBegin.Load(user_data); ok {
if v.(func() bool)() {
return C.bool(true)
} else {
return C.bool(false)
}
return C.bool(false)
}
return true
}

View File

@ -1,8 +1,10 @@
package whisper_test
import (
"errors"
"os"
"runtime"
"sync"
"testing"
"time"
@ -13,10 +15,15 @@ import (
)
const (
ModelPath = "models/ggml-small.en.bin"
ModelPath = "models/ggml-tiny.en.bin"
SamplePath = "samples/jfk.wav"
)
func TestMain(m *testing.M) {
// whisper.DisableLogs() // temporarily disabled to see error messages
os.Exit(m.Run())
}
func Test_Whisper_000(t *testing.T) {
assert := assert.New(t)
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
@ -39,7 +46,7 @@ func Test_Whisper_001(t *testing.T) {
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
defer func() { _ = fh.Close() }()
// Read samples
d := wav.NewDecoder(fh)
@ -89,7 +96,7 @@ func Test_Whisper_003(t *testing.T) {
// Open samples
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
defer func() { _ = fh.Close() }()
// Read samples
d := wav.NewDecoder(fh)
@ -111,3 +118,157 @@ func Test_Whisper_003(t *testing.T) {
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
}
}
// Test_Whisper_State_Init_Free checks that a state can be allocated from a
// context and freed without error. Skipped when the model file is absent.
func Test_Whisper_State_Init_Free(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	ctx := whisper.Whisper_init(ModelPath)
	assert.NotNil(ctx)
	defer ctx.Whisper_free()
	state := ctx.Whisper_init_state()
	assert.NotNil(state)
	state.Whisper_free_state()
}
// Test_Whisper_Full_With_State runs a full transcription through a state
// and checks the results are readable via the *_from_state accessors.
// Skipped when the model or sample file is absent.
func Test_Whisper_Full_With_State(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	// Open samples
	fh, err := os.Open(SamplePath)
	assert.NoError(err)
	defer func() { _ = fh.Close() }()
	// Read samples
	d := wav.NewDecoder(fh)
	buf, err := d.FullPCMBuffer()
	assert.NoError(err)
	data := buf.AsFloat32Buffer().Data
	ctx := whisper.Whisper_init(ModelPath)
	assert.NotNil(ctx)
	defer ctx.Whisper_free()
	state := ctx.Whisper_init_state()
	assert.NotNil(state)
	defer state.Whisper_free_state()
	params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
	// Run using state
	err = ctx.Whisper_full_with_state(state, params, data, nil, nil, nil)
	assert.NoError(err)
	// Validate results are stored in state
	nSegments := ctx.Whisper_full_n_segments_from_state(state)
	assert.GreaterOrEqual(nSegments, 1)
	text := ctx.Whisper_full_get_segment_text_from_state(state, 0)
	assert.NotEmpty(text)
}
// Test_Whisper_Lang_Auto_Detect_With_State loads mel data into a state and
// checks language auto-detection returns one probability per language.
// Skipped when the model or sample file is absent.
func Test_Whisper_Lang_Auto_Detect_With_State(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	// Open samples
	fh, err := os.Open(SamplePath)
	assert.NoError(err)
	defer func() { _ = fh.Close() }()
	// Read samples
	d := wav.NewDecoder(fh)
	buf, err := d.FullPCMBuffer()
	assert.NoError(err)
	data := buf.AsFloat32Buffer().Data
	ctx := whisper.Whisper_init(ModelPath)
	assert.NotNil(ctx)
	defer ctx.Whisper_free()
	state := ctx.Whisper_init_state()
	assert.NotNil(state)
	defer state.Whisper_free_state()
	threads := runtime.NumCPU()
	// Prepare mel into state then detect
	assert.NoError(ctx.Whisper_pcm_to_mel_with_state(state, data, threads))
	probs, err := ctx.Whisper_lang_auto_detect_with_state(state, 0, threads)
	assert.NoError(err)
	assert.Equal(whisper.Whisper_lang_max_id()+1, len(probs))
}
// Test_Whisper_Concurrent_With_State runs two transcriptions from two
// goroutines, each with its own state, against one shared context. Calls
// into the context are serialized with a mutex because whisper.cpp is not
// thread-safe for concurrent use of the same context.
// Skipped when the model or sample file is absent.
func Test_Whisper_Concurrent_With_State(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}
	// Load audio once
	fh, err := os.Open(SamplePath)
	assert.NoError(err)
	defer func() { _ = fh.Close() }()
	dec := wav.NewDecoder(fh)
	buf, err := dec.FullPCMBuffer()
	assert.NoError(err)
	data := buf.AsFloat32Buffer().Data
	ctx := whisper.Whisper_init(ModelPath)
	assert.NotNil(ctx)
	defer ctx.Whisper_free()
	// Each goroutine has its own state
	state1 := ctx.Whisper_init_state()
	state2 := ctx.Whisper_init_state()
	assert.NotNil(state1)
	assert.NotNil(state2)
	defer state1.Whisper_free_state()
	defer state2.Whisper_free_state()
	params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
	var wg sync.WaitGroup
	var mu sync.Mutex // guard calls into shared ctx, per upstream note not thread-safe for same context
	errs := make(chan error, 2)
	worker := func(state *whisper.State) {
		defer wg.Done()
		mu.Lock()
		err := ctx.Whisper_full_with_state(state, params, data, nil, nil, nil)
		if err == nil {
			n := ctx.Whisper_full_n_segments_from_state(state)
			if n <= 0 {
				err = errors.New("no segments")
			} else {
				_ = ctx.Whisper_full_get_segment_text_from_state(state, 0)
			}
		}
		mu.Unlock()
		errs <- err
	}
	wg.Add(2)
	go worker(state1)
	go worker(state2)
	wg.Wait()
	close(errs)
	// Both workers must have succeeded
	for e := range errs {
		assert.NoError(e)
	}
}