Merge a7a8169b99 into fc674574ca
This commit is contained in:
commit
d7698ab1fd
|
|
@ -3,6 +3,7 @@
|
|||
*.d
|
||||
.cache/
|
||||
.coreml/
|
||||
pkg/
|
||||
.test/
|
||||
.venv/
|
||||
.vs/
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
build
|
||||
models
|
||||
samples/a13.wav
|
||||
samples/benchmark_out.wav
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ whisper: mkdir
|
|||
-DBUILD_SHARED_LIBS=OFF
|
||||
cmake --build ../../${BUILD_DIR} --target whisper
|
||||
|
||||
test: model-small whisper modtidy
|
||||
test: model-tiny whisper modtidy
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
|
||||
|
|
@ -46,8 +46,15 @@ endif
|
|||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
model-small: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-small.en.bin
|
||||
benchmark: model-tiny whisper modtidy
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -bench='BenchmarkContextProcessCPU$$' -benchtime=1x -benchmem -run '^$$' ./pkg/whisper/...
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -bench='BenchmarkContextProcessCPU$$' -benchtime=1x -benchmem -run '^$$' ./pkg/whisper/...
|
||||
endif
|
||||
|
||||
model-tiny: mkdir examples/go-model-download
|
||||
@${BUILD_DIR}/go-model-download -out models ggml-tiny.en.bin
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,12 @@ This package provides Go bindings for whisper.cpp. They have been tested on:
|
|||
* Fedora Linux on x86_64
|
||||
|
||||
The "low level" bindings are in the `bindings/go` directory and there is a more
|
||||
Go-style package in the `bindings/go/pkg/whisper` directory. The most simple usage
|
||||
is as follows:
|
||||
Go-style package in the `bindings/go/pkg/whisper` directory.
|
||||
|
||||
Legacy stateless example (single worker). For the recommended stateful API and
|
||||
concurrency-safe usage, see "New high-level API" below. Note: `Model.NewContext()`
|
||||
returns a stateless context for backward compatibility and is not safe for parallel
|
||||
`Process` calls (may return `ErrStatelessBusy`).
|
||||
|
||||
```go
|
||||
import (
|
||||
|
|
@ -100,6 +104,123 @@ Getting help:
|
|||
|
||||
* Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312)
|
||||
|
||||
## New high-level API (stateful and stateless contexts)
|
||||
|
||||
The `pkg/whisper` package now exposes two context kinds:
|
||||
|
||||
- StatefulContext: recommended for concurrency. Each context owns its own whisper_state.
|
||||
- StatelessContext: shares the model context. Simpler, but not suitable for parallel `Process` calls.
|
||||
|
||||
### Quick start: stateful context (recommended)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load model
|
||||
model, err := whisper.NewModelContext("./models/ggml-small.en.bin")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer model.Close()
|
||||
|
||||
// Configure parameters (optional: provide a config func)
|
||||
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {
|
||||
p.SetThreads(4)
|
||||
p.SetLanguage("en") // or "auto"
|
||||
p.SetTranslate(false)
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create stateful context (safe for running in parallel goroutines)
|
||||
ctx, err := whisper.NewStatefulContext(model, params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer ctx.Close()
|
||||
|
||||
// Your 16-bit mono PCM at 16kHz as float32 samples
|
||||
var samples []float32
|
||||
|
||||
// Process. Callbacks are optional.
|
||||
if err := ctx.Process(samples, nil, nil, nil); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Read segments
|
||||
for {
|
||||
seg, err := ctx.NextSegment()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Printf("[%v -> %v] %s\n", seg.Start, seg.End, seg.Text)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Quick start: stateless context (single worker)
|
||||
|
||||
```go
|
||||
// Load model as above
|
||||
model, _ := whisper.NewModelContext("./models/ggml-small.en.bin")
|
||||
defer model.Close()
|
||||
|
||||
params, _ := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
|
||||
ctx, _ := whisper.NewStatelessContext(model, params)
|
||||
defer ctx.Close()
|
||||
|
||||
if err := ctx.Process(samples, nil, nil, nil); err != nil { panic(err) }
|
||||
for {
|
||||
seg, err := ctx.NextSegment()
|
||||
if err != nil { break }
|
||||
fmt.Println(seg.Text)
|
||||
}
|
||||
```
|
||||
|
||||
### Deprecations and migration notes
|
||||
|
||||
- The `Context` interface setters are deprecated (SetThreads, SetLanguage, etc.). Use `Parameters` via `NewParameters` and pass it when creating a context.
|
||||
- `Model.NewContext()` remains for backward compatibility and returns a stateless context by default. Prefer `NewStatefulContext` for concurrency.
|
||||
- Stateless contexts share the model context. A concurrency gate prevents overlapping `Process` calls and will return `ErrStatelessBusy` if another `Process` is in flight.
|
||||
- For parallel processing, create one `StatefulContext` per goroutine.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Benchmarks live in `pkg/whisper` and compare CPU vs GPU, stateful vs stateless, threads, and callback modes.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Model: `models/ggml-small.en.bin` (or your choice).
|
||||
- Sample: `samples/jfk.wav`.
|
||||
- Build the C libs once (also downloads a model for examples):
|
||||
|
||||
```bash
|
||||
cd bindings/go
|
||||
make examples
|
||||
# optionally: ./build/go-model-download -out models
|
||||
```
|
||||
|
||||
### Run benchmarks
|
||||
|
||||
```bash
|
||||
cd bindings/go/pkg/whisper
|
||||
make benchmark
|
||||
```
|
||||
|
||||
### What the benchmarks measure
|
||||
|
||||
- Variants: device (cpu/gpu) x context kind (stateless/stateful) x threads {1,2,4, NumCPU} x callback mode (NoCallback, WithSegmentCallback).
|
||||
- Standard Go benchmark outputs: ns/op, B/op, allocs/op. We also set bytes per op to sample bytes.
|
||||
- Custom metric `ms_process`: wall time per `Process` iteration, reported via `b.ReportMetric`.
|
||||
- When `printTimings` is enabled, model-level timings are printed for NoCallback runs using `model.PrintTimings()`.
|
||||
|
||||
## License
|
||||
|
||||
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
|
||||
|
|
|
|||
|
|
@ -18,9 +18,10 @@ import (
|
|||
// CONSTANTS
|
||||
|
||||
const (
|
||||
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models
|
||||
srcUrlTinydiarize = "https://huggingface.co/akashmjn/tinydiarize-whisper.cpp/resolve/main/"
|
||||
srcExt = ".bin" // Filename extension
|
||||
bufSize = 1024 * 64 // Size of the buffer used for downloading the model
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -38,6 +39,7 @@ var (
|
|||
"large-v2", "large-v2-q5_0", "large-v2-q8_0",
|
||||
"large-v3", "large-v3-q5_0",
|
||||
"large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0",
|
||||
"small.en-tdrz",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -219,6 +221,12 @@ func URLForModel(model string) (string, error) {
|
|||
model += srcExt
|
||||
}
|
||||
|
||||
srcUrl := srcUrl
|
||||
|
||||
if strings.Contains(model, "tdrz") {
|
||||
srcUrl = srcUrlTinydiarize
|
||||
}
|
||||
|
||||
// Parse the base URL
|
||||
url, err := url.Parse(srcUrl)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ module github.com/ggerganov/whisper.cpp/bindings/go
|
|||
go 1.23
|
||||
|
||||
require (
|
||||
github.com/go-audio/audio v1.0.0
|
||||
github.com/go-audio/wav v1.1.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-audio/audio v1.0.0 // indirect
|
||||
github.com/go-audio/riff v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
|
|||
|
|
@ -47,6 +47,16 @@ func (p *Params) SetPrintTimestamps(v bool) {
|
|||
p.print_timestamps = toBool(v)
|
||||
}
|
||||
|
||||
// Enable extra debug information
|
||||
func (p *Params) SetDebugMode(v bool) {
|
||||
p.debug_mode = toBool(v)
|
||||
}
|
||||
|
||||
// Enable tinydiarize speaker turn detection
|
||||
func (p *Params) SetDiarize(v bool) {
|
||||
p.tdrz_enable = toBool(v)
|
||||
}
|
||||
|
||||
// Voice Activity Detection (VAD)
|
||||
func (p *Params) SetVAD(v bool) {
|
||||
p.vad = toBool(v)
|
||||
|
|
@ -179,6 +189,8 @@ func (p *Params) SetInitialPrompt(prompt string) {
|
|||
p.initial_prompt = C.CString(prompt)
|
||||
}
|
||||
|
||||
// SetCarryInitialPrompt if true, always prepend initial_prompt to every decode window
|
||||
// (may reduce conditioning on previous text)
|
||||
func (p *Params) SetCarryInitialPrompt(v bool) {
|
||||
p.carry_initial_prompt = toBool(v)
|
||||
}
|
||||
|
|
@ -236,9 +248,6 @@ func (p *Params) String() string {
|
|||
if p.token_timestamps {
|
||||
str += " token_timestamps"
|
||||
}
|
||||
if p.carry_initial_prompt {
|
||||
str += " carry_initial_prompt"
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
// Gate provides a simple acquire/release contract per key.
|
||||
// The default implementation is a single-entry lock per key (limit=1).
|
||||
type Gate interface {
|
||||
// Acquire returns true if the key was acquired; false if already held
|
||||
Acquire(key any) bool
|
||||
// Release releases the key if currently held
|
||||
Release(key any)
|
||||
}
|
||||
|
||||
// singleFlightGate is a minimal lock with limit=1 per key
|
||||
type singleFlightGate struct {
|
||||
m sync.Map // key -> *int32 (0 available, 1 held)
|
||||
}
|
||||
|
||||
func (g *singleFlightGate) Acquire(key any) bool {
|
||||
ptr, _ := g.m.LoadOrStore(key, new(int32))
|
||||
busy := ptr.(*int32)
|
||||
return atomic.CompareAndSwapInt32(busy, 0, 1)
|
||||
}
|
||||
|
||||
func (g *singleFlightGate) Release(key any) {
|
||||
if v, ok := g.m.Load(key); ok {
|
||||
atomic.StoreInt32(v.(*int32), 0)
|
||||
}
|
||||
}
|
||||
|
||||
var defaultGate Gate = &singleFlightGate{}
|
||||
|
||||
// SetGate allows applications to override the default gate (e.g., for custom policies)
|
||||
// Passing nil resets to the default singleFlightGate.
|
||||
func SetGate(g Gate) {
|
||||
if g == nil {
|
||||
defaultGate = &singleFlightGate{}
|
||||
return
|
||||
}
|
||||
defaultGate = g
|
||||
}
|
||||
|
||||
func gate() Gate { return defaultGate }
|
||||
|
||||
// modelKey derives a stable key per underlying model context for guarding stateless ops
|
||||
func modelKey(model *ModelContext) *whisper.Context {
|
||||
if model == nil || model.ctxAccessor() == nil {
|
||||
return nil
|
||||
}
|
||||
ctx, _ := model.ctxAccessor().context()
|
||||
return ctx
|
||||
}
|
||||
|
|
@ -11,11 +11,21 @@ import (
|
|||
// ERRORS
|
||||
|
||||
var (
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
|
||||
// Deprecated: Use ErrModelClosed instead for checking the model is closed error
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
ErrModelNotMultilingual = errors.New("model is not multilingual")
|
||||
ErrModelClosed = errors.Join(errors.New("model has been closed"), ErrInternalAppError)
|
||||
ErrStatelessBusy = errors.New("stateless context is busy; concurrent processing not supported")
|
||||
|
||||
// Private errors
|
||||
errParametersRequired = errors.New("parameters are required")
|
||||
errModelRequired = errors.New("model is required")
|
||||
errUnableToCreateState = errors.New("unable to create state")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -26,3 +36,10 @@ const SampleRate = whisper.SampleRate
|
|||
|
||||
// SampleBits is the number of bytes per sample.
|
||||
const SampleBits = whisper.SampleBits
|
||||
|
||||
type SamplingStrategy uint32
|
||||
|
||||
const (
|
||||
SAMPLING_GREEDY SamplingStrategy = SamplingStrategy(whisper.SAMPLING_GREEDY)
|
||||
SAMPLING_BEAM_SEARCH SamplingStrategy = SamplingStrategy(whisper.SAMPLING_BEAM_SEARCH)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,385 +0,0 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type context struct {
|
||||
n int
|
||||
model *model
|
||||
params whisper.Params
|
||||
}
|
||||
|
||||
// Make sure context adheres to the interface
|
||||
var _ Context = (*context)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func newContext(model *model, params whisper.Params) (Context, error) {
|
||||
context := new(context)
|
||||
context.model = model
|
||||
context.params = params
|
||||
|
||||
// Return success
|
||||
return context, nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Set the language to use for speech recognition.
|
||||
func (context *context) SetLanguage(lang string) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
|
||||
if lang == "auto" {
|
||||
context.params.SetLanguage(-1)
|
||||
} else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
return err
|
||||
}
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
func (context *context) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
id := context.params.Language()
|
||||
if id == -1 {
|
||||
return "auto"
|
||||
}
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
func (context *context) DetectedLanguage() string {
|
||||
return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id())
|
||||
}
|
||||
|
||||
// Set translate flag
|
||||
func (context *context) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// Voice Activity Detection (VAD)
|
||||
func (context *context) SetVAD(v bool) {
|
||||
context.params.SetVAD(v)
|
||||
}
|
||||
|
||||
func (context *context) SetVADModelPath(path string) {
|
||||
context.params.SetVADModelPath(path)
|
||||
}
|
||||
|
||||
func (context *context) SetVADThreshold(t float32) {
|
||||
context.params.SetVADThreshold(t)
|
||||
}
|
||||
|
||||
func (context *context) SetVADMinSpeechMs(ms int) {
|
||||
context.params.SetVADMinSpeechMs(ms)
|
||||
}
|
||||
|
||||
func (context *context) SetVADMinSilenceMs(ms int) {
|
||||
context.params.SetVADMinSilenceMs(ms)
|
||||
}
|
||||
|
||||
func (context *context) SetVADMaxSpeechSec(s float32) {
|
||||
context.params.SetVADMaxSpeechSec(s)
|
||||
}
|
||||
|
||||
func (context *context) SetVADSpeechPadMs(ms int) {
|
||||
context.params.SetVADSpeechPadMs(ms)
|
||||
}
|
||||
|
||||
func (context *context) SetVADSamplesOverlap(sec float32) {
|
||||
context.params.SetVADSamplesOverlap(sec)
|
||||
}
|
||||
|
||||
func (context *context) SetSplitOnWord(v bool) {
|
||||
context.params.SetSplitOnWord(v)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (context *context) SetThreads(v uint) {
|
||||
context.params.SetThreads(int(v))
|
||||
}
|
||||
|
||||
// Set time offset
|
||||
func (context *context) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set duration of audio to process
|
||||
func (context *context) SetDuration(v time.Duration) {
|
||||
context.params.SetDuration(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (context *context) SetTokenThreshold(t float32) {
|
||||
context.params.SetTokenThreshold(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (context *context) SetTokenSumThreshold(t float32) {
|
||||
context.params.SetTokenSumThreshold(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (context *context) SetMaxSegmentLength(n uint) {
|
||||
context.params.SetMaxSegmentLength(int(n))
|
||||
}
|
||||
|
||||
// Set token timestamps flag
|
||||
func (context *context) SetTokenTimestamps(b bool) {
|
||||
context.params.SetTokenTimestamps(b)
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// Set audio encoder context
|
||||
func (context *context) SetAudioCtx(n uint) {
|
||||
context.params.SetAudioCtx(int(n))
|
||||
}
|
||||
|
||||
// Set maximum number of text context tokens to store
|
||||
func (context *context) SetMaxContext(n int) {
|
||||
context.params.SetMaxContext(n)
|
||||
}
|
||||
|
||||
// Set Beam Size
|
||||
func (context *context) SetBeamSize(n int) {
|
||||
context.params.SetBeamSize(n)
|
||||
}
|
||||
|
||||
// Set Entropy threshold
|
||||
func (context *context) SetEntropyThold(t float32) {
|
||||
context.params.SetEntropyThold(t)
|
||||
}
|
||||
|
||||
// Set Temperature
|
||||
func (context *context) SetTemperature(t float32) {
|
||||
context.params.SetTemperature(t)
|
||||
}
|
||||
|
||||
// Set the fallback temperature incrementation
|
||||
// Pass -1.0 to disable this feature
|
||||
func (context *context) SetTemperatureFallback(t float32) {
|
||||
context.params.SetTemperatureFallback(t)
|
||||
}
|
||||
|
||||
// Set initial prompt
|
||||
func (context *context) SetInitialPrompt(prompt string) {
|
||||
context.params.SetInitialPrompt(prompt)
|
||||
}
|
||||
|
||||
// ResetTimings resets the mode timings. Should be called before processing
|
||||
func (context *context) ResetTimings() {
|
||||
context.model.ctx.Whisper_reset_timings()
|
||||
}
|
||||
|
||||
// PrintTimings prints the model timings to stdout.
|
||||
func (context *context) PrintTimings() {
|
||||
context.model.ctx.Whisper_print_timings()
|
||||
}
|
||||
|
||||
// SystemInfo returns the system information
|
||||
func (context *context) SystemInfo() string {
|
||||
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
|
||||
context.params.Threads(),
|
||||
runtime.NumCPU(),
|
||||
whisper.Whisper_print_system_info(),
|
||||
)
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages.
|
||||
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
|
||||
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return langProbs, nil
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(
|
||||
data []float32,
|
||||
callEncoderBegin EncoderBeginCallback,
|
||||
callNewSegment SegmentCallback,
|
||||
callProgress ProgressCallback,
|
||||
) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if callNewSegment != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
// We don't do parallel processing at the moment
|
||||
processors := 0
|
||||
if processors > 1 {
|
||||
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
|
||||
func(new int) {
|
||||
if callNewSegment != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
callNewSegment(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
|
||||
func(new int) {
|
||||
if callNewSegment != nil {
|
||||
num_segments := context.model.ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
callNewSegment(toSegment(context.model.ctx, i))
|
||||
}
|
||||
}
|
||||
}, func(progress int) {
|
||||
if callProgress != nil {
|
||||
callProgress(progress)
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reset n so that more Segments can be available within NextSegment call
|
||||
context.n = 0
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the next segment of tokens
|
||||
func (context *context) NextSegment() (Segment, error) {
|
||||
if context.model.ctx == nil {
|
||||
return Segment{}, ErrInternalAppError
|
||||
}
|
||||
if context.n >= context.model.ctx.Whisper_full_n_segments() {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
// Populate result
|
||||
result := toSegment(context.model.ctx, context.n)
|
||||
|
||||
// Increment the cursor
|
||||
context.n++
|
||||
|
||||
// Return success
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Test for text tokens
|
||||
func (context *context) IsText(t Token) bool {
|
||||
switch {
|
||||
case context.IsBEG(t):
|
||||
return false
|
||||
case context.IsSOT(t):
|
||||
return false
|
||||
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
|
||||
return false
|
||||
case context.IsPREV(t):
|
||||
return false
|
||||
case context.IsSOLM(t):
|
||||
return false
|
||||
case context.IsNOT(t):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Test for "begin" token
|
||||
func (context *context) IsBEG(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
|
||||
}
|
||||
|
||||
// Test for "start of transcription" token
|
||||
func (context *context) IsSOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
|
||||
}
|
||||
|
||||
// Test for "end of transcription" token
|
||||
func (context *context) IsEOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
|
||||
}
|
||||
|
||||
// Test for "start of prev" token
|
||||
func (context *context) IsPREV(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
|
||||
}
|
||||
|
||||
// Test for "start of lm" token
|
||||
func (context *context) IsSOLM(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
|
||||
}
|
||||
|
||||
// Test for "No timestamps" token
|
||||
func (context *context) IsNOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
|
||||
}
|
||||
|
||||
// Test for token associated with a specific language
|
||||
func (context *context) IsLANG(t Token, lang string) bool {
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func toSegment(ctx *whisper.Context, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||
Tokens: toTokens(ctx, n),
|
||||
}
|
||||
}
|
||||
|
||||
func toTokens(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
for i := 0; i < len(result); i++ {
|
||||
data := ctx.Whisper_full_get_token_data(n, i)
|
||||
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: ctx.Whisper_full_get_token_text(n, i),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
Start: time.Duration(data.T0()) * time.Millisecond * 10,
|
||||
End: time.Duration(data.T1()) * time.Millisecond * 10,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-audio/audio"
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func processAndExtractSegmentsSequentially(ctx whisper.Context, samples []float32) ([]whisper.Segment, error) {
|
||||
if err := ctx.Process(samples, nil, nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var segments []whisper.Segment
|
||||
for {
|
||||
seg, err := ctx.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
segments = append(segments, seg)
|
||||
}
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
|
||||
func processAndExtractSegmentsWithCallback(ctx whisper.Context, samples []float32) ([]whisper.Segment, error) {
|
||||
segments := make([]whisper.Segment, 0)
|
||||
|
||||
if err := ctx.Process(samples, nil, func(seg whisper.Segment) {
|
||||
segments = append(segments, seg)
|
||||
}, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return segments, nil
|
||||
}
|
||||
|
||||
// benchProcessVariants runs the common benchmark matrix across context kinds,
|
||||
// thread sets, and callback modes, for given samples. If singleIteration is true
|
||||
// it runs only one iteration regardless of b.N. If printTimings is true,
|
||||
// model timings and custom ms_process metric are reported for NoCallback runs.
|
||||
func benchProcessVariants(
|
||||
b *testing.B,
|
||||
samples []float32,
|
||||
singleIteration bool,
|
||||
printTimings bool,
|
||||
useGPU bool,
|
||||
) {
|
||||
threadSets := []uint{1, 2, 4, uint(runtime.NumCPU())}
|
||||
|
||||
device := "cpu"
|
||||
if useGPU {
|
||||
device = "gpu"
|
||||
}
|
||||
|
||||
// Initialize model per device mode
|
||||
mp := whisper.NewModelContextParams()
|
||||
mp.SetUseGPU(useGPU)
|
||||
model, err := whisper.NewModelContextWithParams(ModelPath, mp)
|
||||
if err != nil {
|
||||
b.Fatalf("load model (%s): %v", device, err)
|
||||
}
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
// Context kinds: stateless and stateful
|
||||
ctxKinds := []struct {
|
||||
name string
|
||||
new func() (whisper.Context, error)
|
||||
}{
|
||||
{
|
||||
name: "stateless",
|
||||
new: func() (whisper.Context, error) {
|
||||
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, func(p *whisper.Parameters) {})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return whisper.NewStatelessContext(model, params)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "stateful",
|
||||
new: func() (whisper.Context, error) {
|
||||
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return whisper.NewStatefulContext(model, params)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, kind := range ctxKinds {
|
||||
b.Run(device+"/"+kind.name, func(b *testing.B) {
|
||||
for _, threads := range threadSets {
|
||||
b.Run(fmt.Sprintf("threads=%d/NoCallback", threads), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(samples) * 4))
|
||||
ctx, err := kind.new()
|
||||
if err != nil {
|
||||
b.Fatalf("new %s context: %v", kind.name, err)
|
||||
}
|
||||
defer func() { _ = ctx.Close() }()
|
||||
ctx.SetThreads(threads)
|
||||
|
||||
iters := b.N
|
||||
if singleIteration {
|
||||
iters = 1
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < iters; i++ {
|
||||
model.ResetTimings()
|
||||
start := time.Now()
|
||||
|
||||
segments, err := processAndExtractSegmentsSequentially(ctx, samples)
|
||||
if err != nil {
|
||||
b.Fatalf("process and extract segments sequentially: %v", err)
|
||||
}
|
||||
|
||||
b.Logf("segments: %+v", segments)
|
||||
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if printTimings {
|
||||
model.PrintTimings()
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(elapsed.Milliseconds()), "ms_process")
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("threads=%d/WithSegmentCallback", threads), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(samples) * 4))
|
||||
ctx, err := kind.new()
|
||||
if err != nil {
|
||||
b.Fatalf("new %s context: %v", kind.name, err)
|
||||
}
|
||||
defer func() { _ = ctx.Close() }()
|
||||
ctx.SetThreads(threads)
|
||||
|
||||
iters := b.N
|
||||
if singleIteration {
|
||||
iters = 1
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < iters; i++ {
|
||||
start := time.Now()
|
||||
model.ResetTimings()
|
||||
|
||||
// Passing a segment callback forces single-segment mode and exercises token extraction
|
||||
segments, err := processAndExtractSegmentsWithCallback(ctx, samples)
|
||||
if err != nil {
|
||||
b.Fatalf("process with callback: %v", err)
|
||||
}
|
||||
|
||||
b.Logf("segments: %+v", segments)
|
||||
|
||||
elapsed := time.Since(start)
|
||||
if printTimings {
|
||||
model.PrintTimings()
|
||||
}
|
||||
|
||||
b.ReportMetric(float64(elapsed.Milliseconds()), "ms_process")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkContextProcess runs the high-level Context.Process across
|
||||
// different thread counts, with and without segment callbacks.
|
||||
func BenchmarkContextProcessCPU(b *testing.B) {
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
b.Skipf("model not found: %s", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
b.Skipf("sample not found: %s", SamplePath)
|
||||
}
|
||||
|
||||
// Load audio once (reuse helper)
|
||||
data := helperLoadSample(b, SamplePath)
|
||||
|
||||
benchProcessVariants(b, data, false, true, false)
|
||||
}
|
||||
|
||||
// BenchmarkContextProcessBig runs one single iteration over a big input
|
||||
// (the short sample concatenated 10x) to simulate long audio processing.
|
||||
// This is complementary to BenchmarkContextProcess which runs many iterations
|
||||
// over the short sample.
|
||||
func BenchmarkContextProcessBigCPU(b *testing.B) {
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
b.Skipf("model not found: %s", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
b.Skipf("sample not found: %s", SamplePath)
|
||||
}
|
||||
|
||||
// Load audio once (reuse helper with meta)
|
||||
data, sampleRate, numChans := helperLoadSampleWithMeta(b, SamplePath)
|
||||
|
||||
// Build big dataset: input concatenated 10x
|
||||
bigData := make([]float32, len(data)*10)
|
||||
for i := 0; i < 10; i++ {
|
||||
copy(bigData[i*len(data):(i+1)*len(data)], data)
|
||||
}
|
||||
|
||||
// Write the big dataset to a wav file for inspection
|
||||
outPath := "../../samples/benchmark_out.wav"
|
||||
fout, err := os.Create(outPath)
|
||||
if err != nil {
|
||||
b.Fatalf("create output wav: %v", err)
|
||||
}
|
||||
enc := wav.NewEncoder(fout, sampleRate, 16, numChans, 1)
|
||||
intBuf := &audio.IntBuffer{
|
||||
Format: &audio.Format{NumChannels: numChans, SampleRate: sampleRate},
|
||||
SourceBitDepth: 16,
|
||||
Data: make([]int, len(bigData)),
|
||||
}
|
||||
for i, s := range bigData {
|
||||
v := int(math.Round(float64(s) * 32767.0))
|
||||
if v > 32767 {
|
||||
v = 32767
|
||||
} else if v < -32768 {
|
||||
v = -32768
|
||||
}
|
||||
intBuf.Data[i] = v
|
||||
}
|
||||
if err := enc.Write(intBuf); err != nil {
|
||||
_ = fout.Close()
|
||||
b.Fatalf("encode wav: %v", err)
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
_ = fout.Close()
|
||||
b.Fatalf("close encoder: %v", err)
|
||||
}
|
||||
_ = fout.Close()
|
||||
|
||||
benchProcessVariants(b, bigData, true, true, false)
|
||||
}
|
||||
|
||||
// GPU variants reuse model-level GPU enablement via model params
|
||||
func BenchmarkContextProcessGPU(b *testing.B) {
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
b.Skipf("model not found: %s", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
b.Skipf("sample not found: %s", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(b, SamplePath)
|
||||
|
||||
benchProcessVariants(b, data, false, true, true)
|
||||
}
|
||||
|
||||
func BenchmarkContextProcessBigGPU(b *testing.B) {
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
b.Skipf("model not found: %s", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
b.Skipf("sample not found: %s", SamplePath)
|
||||
}
|
||||
|
||||
data, _, _ := helperLoadSampleWithMeta(b, SamplePath)
|
||||
|
||||
bigData := make([]float32, len(data)*10)
|
||||
for i := 0; i < 10; i++ {
|
||||
copy(bigData[i*len(data):(i+1)*len(data)], data)
|
||||
}
|
||||
|
||||
benchProcessVariants(b, bigData, true, true, true)
|
||||
}
|
||||
|
|
@ -1,124 +1,526 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
"github.com/go-audio/wav"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetLanguage(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
|
||||
// This returns an error since
|
||||
// the model 'models/ggml-small.en.bin'
|
||||
// that is loaded is not multilingual
|
||||
err = context.SetLanguage("en")
|
||||
assert.Error(err)
|
||||
// This returns an error since the small.en model is not multilingual
|
||||
err := ctx.SetLanguage("en")
|
||||
assert.Error(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextModelIsMultilingual(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
|
||||
isMultilingual := context.IsMultilingual()
|
||||
|
||||
// This returns false since
|
||||
// the model 'models/ggml-small.en.bin'
|
||||
// that is loaded is not multilingual
|
||||
assert.False(isMultilingual)
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
assert.False(ctx.IsMultilingual())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLanguage(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
expectedLanguage := "en"
|
||||
actualLanguage := ctx.Language()
|
||||
assert.Equal(expectedLanguage, actualLanguage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// This always returns en since
|
||||
// the model 'models/ggml-small.en.bin'
|
||||
// that is loaded is not multilingual
|
||||
expectedLanguage := "en"
|
||||
actualLanguage := context.Language()
|
||||
assert.Equal(expectedLanguage, actualLanguage)
|
||||
// Generic behavior: Language() and DetectedLanguage() match for both context types
|
||||
func TestContext_Generic_LanguageAndDetectedLanguage(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
|
||||
langBefore := ctx.Language()
|
||||
assert.NoError(ctx.Process(data, nil, nil, nil))
|
||||
detected := ctx.DetectedLanguage()
|
||||
assert.Equal(langBefore, detected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcess(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
assert.Equal(uint16(1), dec.NumChans)
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
|
||||
err = context.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
err := ctx.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectedLanguage(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
assert.Equal(uint16(1), dec.NumChans)
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
|
||||
err = context.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
expectedLanguage := "en"
|
||||
actualLanguage := context.DetectedLanguage()
|
||||
assert.Equal(expectedLanguage, actualLanguage)
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
err := ctx.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
expectedLanguage := "en"
|
||||
actualLanguage := ctx.DetectedLanguage()
|
||||
assert.Equal(expectedLanguage, actualLanguage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContext_ConcurrentProcessing tests that multiple contexts can process concurrently
|
||||
// without interfering with each other (validates the whisper_state isolation fix)
|
||||
func TestContext_ConcurrentProcessing(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
|
||||
err := ctx.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
seg, err := ctx.NextSegment()
|
||||
assert.NoError(err)
|
||||
assert.NotEmpty(seg.Text)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContext_Close tests that Context.Close() properly frees resources
|
||||
// and allows context to be used even after it has been closed
|
||||
func TestContext_Close(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
|
||||
// Close the context
|
||||
err := ctx.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to use closed context - should return errors
|
||||
err = ctx.Process([]float32{0.1, 0.2, 0.3}, nil, nil, nil)
|
||||
require.ErrorIs(t, err, whisper.ErrModelClosed)
|
||||
// TODO: remove this logic after deprecating the ErrInternalAppError
|
||||
require.ErrorIs(t, err, whisper.ErrInternalAppError)
|
||||
|
||||
lang := ctx.DetectedLanguage()
|
||||
require.Empty(t, lang)
|
||||
|
||||
_, err = ctx.NextSegment()
|
||||
assert.ErrorIs(err, whisper.ErrModelClosed)
|
||||
// TODO: remove this logic after deprecating the ErrInternalAppError
|
||||
assert.ErrorIs(err, whisper.ErrInternalAppError)
|
||||
|
||||
// Multiple closes should be safe
|
||||
err = ctx.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Close_Context_of_Closed_Model(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
|
||||
t.Run("stateless", func(t *testing.T) {
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = model.Close() }()
|
||||
params := helperNewParams(t, model, nil)
|
||||
ctx, err := whisper.NewStatelessContext(model, params)
|
||||
assert.NoError(err)
|
||||
require.NoError(t, model.Close())
|
||||
require.NoError(t, ctx.Close())
|
||||
})
|
||||
|
||||
t.Run("stateful", func(t *testing.T) {
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = model.Close() }()
|
||||
params := helperNewParams(t, model, nil)
|
||||
ctx, err := whisper.NewStatefulContext(model, params)
|
||||
assert.NoError(err)
|
||||
require.NoError(t, model.Close())
|
||||
require.NoError(t, ctx.Close())
|
||||
})
|
||||
}
|
||||
|
||||
func TestContext_VAD_And_Diarization_Params_DoNotPanic(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, nil)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(params)
|
||||
|
||||
ctx, err := whisper.NewStatefulContext(model, params)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = ctx.Close() }()
|
||||
|
||||
p := ctx.Params()
|
||||
p.SetDiarize(true)
|
||||
p.SetVAD(true)
|
||||
p.SetVADThreshold(0.5)
|
||||
p.SetVADMinSpeechMs(200)
|
||||
p.SetVADMinSilenceMs(100)
|
||||
p.SetVADMaxSpeechSec(10)
|
||||
p.SetVADSpeechPadMs(30)
|
||||
p.SetVADSamplesOverlap(0.02)
|
||||
|
||||
err = ctx.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
func TestContext_SpeakerTurnNext_Field_Present(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
new func(t *testing.T) (whisper.Context, func())
|
||||
}{
|
||||
{name: "stateless", new: helperNewStatelessContext},
|
||||
{name: "stateful", new: helperNewStatefulContext},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cleanup := tc.new(t)
|
||||
defer cleanup()
|
||||
|
||||
err := ctx.Process(data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
seg, err := ctx.NextSegment()
|
||||
assert.NoError(err)
|
||||
t.Logf("SpeakerTurnNext: %v", seg.SpeakerTurnNext)
|
||||
_ = seg.SpeakerTurnNext
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure Process produces at least one segment for both stateless and stateful contexts
|
||||
func TestContext_Process_ProducesSegments_BothKinds(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
// Stateless
|
||||
stateless, cleanupS := helperNewStatelessContext(t)
|
||||
defer cleanupS()
|
||||
require.NoError(t, stateless.Process(data, nil, nil, nil))
|
||||
var statelessCount int
|
||||
for {
|
||||
_, err := stateless.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
statelessCount++
|
||||
}
|
||||
assert.Greater(statelessCount, 0, "stateless should produce at least one segment")
|
||||
|
||||
// Stateful
|
||||
stateful, cleanupSt := helperNewStatefulContext(t)
|
||||
defer cleanupSt()
|
||||
require.NoError(t, stateful.Process(data, nil, nil, nil))
|
||||
var statefulCount int
|
||||
for {
|
||||
_, err := stateful.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
statefulCount++
|
||||
}
|
||||
assert.Greater(statefulCount, 0, "stateful should produce at least one segment")
|
||||
}
|
||||
|
||||
// With temperature=0 (greedy), stateless and stateful should produce identical segments
|
||||
func TestContext_Process_SameResults_TemperatureZero(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
// Use a single model to avoid environment differences
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
// Independent params with temperature=0 for determinism
|
||||
p := helperNewParams(t, model, func(p *whisper.Parameters) {
|
||||
p.SetTemperature(0)
|
||||
p.SetThreads(1)
|
||||
})
|
||||
|
||||
stateless, err := whisper.NewStatelessContext(model, p)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = stateless.Close() }()
|
||||
|
||||
stateful, err := whisper.NewStatefulContext(model, p)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = stateful.Close() }()
|
||||
|
||||
require.NoError(t, stateless.Process(data, nil, nil, nil))
|
||||
require.NoError(t, stateful.Process(data, nil, nil, nil))
|
||||
|
||||
// Collect segment texts
|
||||
var segsStateless, segsStateful []string
|
||||
for {
|
||||
seg, err := stateless.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
segsStateless = append(segsStateless, seg.Text)
|
||||
}
|
||||
for {
|
||||
seg, err := stateful.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
segsStateful = append(segsStateful, seg.Text)
|
||||
}
|
||||
|
||||
// Both should have at least one segment and be identical
|
||||
require.Greater(t, len(segsStateless), 0)
|
||||
require.Greater(t, len(segsStateful), 0)
|
||||
assert.Equal(len(segsStateful), len(segsStateless))
|
||||
for i := range segsStateless {
|
||||
assert.Equal(segsStateless[i], segsStateful[i], "segment %d text differs", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Model.GetTimings: stateless processing updates model timings (non-zero),
|
||||
// stateful processing does not (zero timings)
|
||||
func TestModel_GetTimings_Stateless_NonZero_Stateful_Zero(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
// Stateless should produce non-zero timings
|
||||
t.Run("stateless", func(t *testing.T) {
|
||||
model.ResetTimings()
|
||||
params := helperNewParams(t, model, nil)
|
||||
ctx, err := whisper.NewStatelessContext(model, params)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ctx.Close() }()
|
||||
|
||||
require.NoError(t, ctx.Process(data, nil, nil, nil))
|
||||
|
||||
timings, ok := model.GetTimings()
|
||||
require.True(t, ok, "expected timings to be available after stateless processing")
|
||||
nonZero := timings.SampleMS > 0 || timings.EncodeMS > 0 || timings.DecodeMS > 0 || timings.BatchdMS > 0 || timings.PromptMS > 0
|
||||
assert.True(nonZero, "expected at least one non-zero timing after stateless processing: %#v", timings)
|
||||
})
|
||||
|
||||
// Stateful should keep model-level timings at zero
|
||||
t.Run("stateful", func(t *testing.T) {
|
||||
model.ResetTimings()
|
||||
params := helperNewParams(t, model, nil)
|
||||
ctx, err := whisper.NewStatefulContext(model, params)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ctx.Close() }()
|
||||
|
||||
require.NoError(t, ctx.Process(data, nil, nil, nil))
|
||||
|
||||
timings, ok := model.GetTimings()
|
||||
// Expect timings present but all zero; if not present at all, treat as zero-equivalent
|
||||
if ok {
|
||||
assert.Equal(float32(0), timings.SampleMS)
|
||||
assert.Equal(float32(0), timings.EncodeMS)
|
||||
assert.Equal(float32(0), timings.DecodeMS)
|
||||
assert.Equal(float32(0), timings.BatchdMS)
|
||||
assert.Equal(float32(0), timings.PromptMS)
|
||||
} else {
|
||||
t.Log("timings not available for stateful processing; treating as zero")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -20,15 +21,21 @@ type ProgressCallback func(int)
|
|||
// continue processing. It is called during the Process function
|
||||
type EncoderBeginCallback func() bool
|
||||
|
||||
type ParamsConfigure func(*Parameters)
|
||||
|
||||
// Model is the interface to a whisper model. Create a new model with the
|
||||
// function whisper.New(string)
|
||||
// Deprecated: Use NewModel implementation struct instead of relying on this interface
|
||||
type Model interface {
|
||||
io.Closer
|
||||
|
||||
// Return a new speech-to-text context.
|
||||
// It may return an error is the model is not loaded or closed
|
||||
// Deprecated: Use NewContext implementation struct instead of relying on this interface
|
||||
NewContext() (Context, error)
|
||||
|
||||
// Return true if the model is multilingual.
|
||||
// It returns false if the model is not loaded or closed
|
||||
IsMultilingual() bool
|
||||
|
||||
// Return all languages supported.
|
||||
|
|
@ -36,38 +43,91 @@ type Model interface {
|
|||
}
|
||||
|
||||
// Context is the speech recognition context.
|
||||
// Deprecated: Use NewContext implementation struct instead of relying on this interface
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
|
||||
SetTranslate(bool) // Set translate flag
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
DetectedLanguage() string // Get detected language
|
||||
io.Closer
|
||||
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSplitOnWord(bool) // Set split on word flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
SetAudioCtx(uint) // Set audio encoder context
|
||||
SetMaxContext(n int) // Set maximum number of text context tokens to store
|
||||
SetBeamSize(n int) // Set Beam Size
|
||||
SetEntropyThold(t float32) // Set Entropy threshold
|
||||
SetInitialPrompt(prompt string) // Set initial prompt
|
||||
SetTemperature(t float32) // Set temperature
|
||||
SetTemperatureFallback(t float32) // Set temperature incrementation
|
||||
// Deprecated: Use Params().SetLanguage() instead
|
||||
SetLanguage(string) error
|
||||
|
||||
SetVAD(v bool)
|
||||
SetVADModelPath(path string)
|
||||
SetVADThreshold(t float32)
|
||||
SetVADMinSpeechMs(ms int)
|
||||
SetVADMinSilenceMs(ms int)
|
||||
SetVADMaxSpeechSec(s float32)
|
||||
SetVADSpeechPadMs(ms int)
|
||||
SetVADSamplesOverlap(sec float32)
|
||||
// Deprecated: Use Params().SetTranslate() instead
|
||||
SetTranslate(bool)
|
||||
|
||||
// Deprecated: Use Params().SetSplitOnWord() instead
|
||||
SetSplitOnWord(bool)
|
||||
|
||||
// Deprecated: Use Params().SetThreads() instead
|
||||
SetThreads(uint)
|
||||
|
||||
// Deprecated: Use Params().SetOffset() instead
|
||||
SetOffset(time.Duration)
|
||||
|
||||
// Deprecated: Use Params().SetDuration() instead
|
||||
SetDuration(time.Duration)
|
||||
|
||||
// Deprecated: Use Params().SetTokenThreshold() instead
|
||||
SetTokenThreshold(float32)
|
||||
|
||||
// Deprecated: Use Params().SetTokenSumThreshold() instead
|
||||
SetTokenSumThreshold(float32)
|
||||
// Deprecated: Use Params().SetMaxSegmentLength() instead
|
||||
|
||||
SetMaxSegmentLength(uint)
|
||||
|
||||
// Deprecated: Use Params().SetTokenTimestamps() instead
|
||||
SetTokenTimestamps(bool)
|
||||
|
||||
// Deprecated: Use Params().SetMaxTokensPerSegment() instead
|
||||
SetMaxTokensPerSegment(uint)
|
||||
|
||||
// Deprecated: Use Params().SetAudioCtx() instead
|
||||
SetAudioCtx(uint)
|
||||
|
||||
// Deprecated: Use Params().SetMaxContext() instead
|
||||
SetMaxContext(int)
|
||||
|
||||
// Deprecated: Use Params().SetBeamSize() instead
|
||||
SetBeamSize(int)
|
||||
|
||||
// Deprecated: Use Params().SetEntropyThold() instead
|
||||
SetEntropyThold(float32)
|
||||
|
||||
// Deprecated: Use Params().SetTemperature() instead
|
||||
SetTemperature(float32)
|
||||
|
||||
// Deprecated: Use Params().SetTemperatureFallback() instead
|
||||
SetTemperatureFallback(float32)
|
||||
|
||||
// Deprecated: Use Params().SetInitialPrompt() instead
|
||||
SetInitialPrompt(string)
|
||||
|
||||
// Get language of the context parameters
|
||||
// Deprecated: Use Params().Language() instead
|
||||
Language() string
|
||||
|
||||
// Deprecated: Use Model().IsMultilingual() instead
|
||||
IsMultilingual() bool
|
||||
|
||||
// Get detected language
|
||||
DetectedLanguage() string
|
||||
|
||||
// Voice Activity Detection (VAD) methods
|
||||
// Deprecated: Use Params().SetVAD() instead
|
||||
SetVAD(bool)
|
||||
// Deprecated: Use Params().SetVADModelPath() instead
|
||||
SetVADModelPath(string)
|
||||
// Deprecated: Use Params().SetVADThreshold() instead
|
||||
SetVADThreshold(float32)
|
||||
// Deprecated: Use Params().SetVADMinSpeechMs() instead
|
||||
SetVADMinSpeechMs(int)
|
||||
// Deprecated: Use Params().SetVADMinSilenceMs() instead
|
||||
SetVADMinSilenceMs(int)
|
||||
// Deprecated: Use Params().SetVADMaxSpeechSec() instead
|
||||
SetVADMaxSpeechSec(float32)
|
||||
// Deprecated: Use Params().SetVADSpeechPadMs() instead
|
||||
SetVADSpeechPadMs(int)
|
||||
// Deprecated: Use Params().SetVADSamplesOverlap() instead
|
||||
SetVADSamplesOverlap(float32)
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
|
|
@ -78,19 +138,39 @@ type Context interface {
|
|||
// is reached, when io.EOF is returned.
|
||||
NextSegment() (Segment, error)
|
||||
|
||||
IsBEG(Token) bool // Test for "begin" token
|
||||
IsSOT(Token) bool // Test for "start of transcription" token
|
||||
IsEOT(Token) bool // Test for "end of transcription" token
|
||||
IsPREV(Token) bool // Test for "start of prev" token
|
||||
IsSOLM(Token) bool // Test for "start of lm" token
|
||||
IsNOT(Token) bool // Test for "No timestamps" token
|
||||
IsLANG(Token, string) bool // Test for token associated with a specific language
|
||||
IsText(Token) bool // Test for text token
|
||||
// Deprecated: Use Model().TokenIdentifier().IsBEG() instead
|
||||
IsBEG(Token) bool
|
||||
|
||||
// Timings
|
||||
// Deprecated: Use Model().TokenIdentifier().IsSOT() instead
|
||||
IsSOT(Token) bool
|
||||
|
||||
// Deprecated: Use Model().TokenIdentifier().IsEOT() instead
|
||||
IsEOT(Token) bool
|
||||
|
||||
// Deprecated: Use Model().TokenIdentifier().IsPREV() instead
|
||||
IsPREV(Token) bool
|
||||
|
||||
// Deprecated: Use Model().TokenIdentifier().IsSOLM() instead
|
||||
IsSOLM(Token) bool
|
||||
|
||||
// Deprecated: Use Model().TokenIdentifier().IsNOT() instead
|
||||
IsNOT(Token) bool
|
||||
|
||||
// Deprecated: Use Model().TokenIdentifier().IsLANG() instead
|
||||
IsLANG(Token, string) bool
|
||||
|
||||
// Deprecated: Use Model().TokenIdentifier().IsText() instead
|
||||
IsText(Token) bool
|
||||
|
||||
// Deprecated: Use Model().PrintTimings() instead
|
||||
// these are model-level performance metrics
|
||||
PrintTimings()
|
||||
|
||||
// Deprecated: Use Model().ResetTimings() instead
|
||||
// these are model-level performance metrics
|
||||
ResetTimings()
|
||||
|
||||
// SystemInfo returns the system information
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
|
|
@ -107,12 +187,29 @@ type Segment struct {
|
|||
|
||||
// The tokens of the segment.
|
||||
Tokens []Token
|
||||
|
||||
// True if the next segment is predicted as a speaker turn (tinydiarize)
|
||||
// It works only with the diarization supporting models (like small.en-tdrz.bin) with the diarization enabled
|
||||
// using Parameters.SetDiarize(true)
|
||||
SpeakerTurnNext bool
|
||||
}
|
||||
|
||||
func (s Segment) String() string {
|
||||
// foramt: [00:01:39.000 --> 00:01:50.000] And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
|
||||
return fmt.Sprintf("[%s --> %s] %s", s.Start.Truncate(time.Millisecond), s.End.Truncate(time.Millisecond), s.Text)
|
||||
}
|
||||
|
||||
// Token is a text or special token
|
||||
type Token struct {
|
||||
Id int
|
||||
Text string
|
||||
P float32
|
||||
// ID of the token
|
||||
Id int
|
||||
|
||||
// Text of the token
|
||||
Text string
|
||||
|
||||
// Probability of the token
|
||||
P float32
|
||||
|
||||
// Timestamp of the token
|
||||
Start, End time.Duration
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
package whisper
|
||||
|
||||
import low "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
|
||||
// DisableLogs disables all C-side logging from whisper.cpp and ggml.
|
||||
// Call once early in your program before creating models/contexts.
|
||||
func DisableLogs() {
|
||||
low.DisableLogs()
|
||||
}
|
||||
|
|
@ -3,99 +3,177 @@ package whisper
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
low "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TYPES
|
||||
|
||||
type model struct {
|
||||
path string
|
||||
ctx *whisper.Context
|
||||
type ModelContext struct {
|
||||
path string
|
||||
ca *ctxAccessor
|
||||
tokId *tokenIdentifier
|
||||
}
|
||||
|
||||
// Make sure model adheres to the interface
|
||||
var _ Model = (*model)(nil)
|
||||
var _ Model = (*ModelContext)(nil)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
// Timings is a compact, high-level timing snapshot in milliseconds
|
||||
type Timings struct {
|
||||
SampleMS float32
|
||||
EncodeMS float32
|
||||
DecodeMS float32
|
||||
BatchdMS float32
|
||||
PromptMS float32
|
||||
}
|
||||
|
||||
// Deprecated: Use NewModelContext instead
|
||||
func New(path string) (Model, error) {
|
||||
model := new(model)
|
||||
return NewModelContext(path)
|
||||
}
|
||||
|
||||
// NewModelContext creates a new model context
|
||||
|
||||
func NewModelContext(
|
||||
path string,
|
||||
) (*ModelContext, error) {
|
||||
return NewModelContextWithParams(
|
||||
path,
|
||||
NewModelContextParams(),
|
||||
)
|
||||
}
|
||||
|
||||
// NewModelContextWithParams creates a new model context with custom initialization params
|
||||
func NewModelContextWithParams(
|
||||
path string,
|
||||
params ModelContextParams,
|
||||
) (*ModelContext, error) {
|
||||
model := new(ModelContext)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
} else {
|
||||
model.ctx = ctx
|
||||
model.path = path
|
||||
}
|
||||
|
||||
// Return success
|
||||
ctx := low.Whisper_init_with_params(path, params.toLow())
|
||||
if ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
}
|
||||
|
||||
model.ca = newCtxAccessor(ctx)
|
||||
model.tokId = newTokenIdentifier(model.ca)
|
||||
model.path = path
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (model *model) Close() error {
|
||||
if model.ctx != nil {
|
||||
model.ctx.Whisper_free()
|
||||
}
|
||||
|
||||
// Release resources
|
||||
model.ctx = nil
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
func (model *ModelContext) Close() error {
|
||||
return model.ca.close()
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// STRINGIFY
|
||||
func (model *ModelContext) ctxAccessor() *ctxAccessor {
|
||||
return model.ca
|
||||
}
|
||||
|
||||
func (model *model) String() string {
|
||||
func (model *ModelContext) String() string {
|
||||
str := "<whisper.model"
|
||||
if model.ctx != nil {
|
||||
if model.ca != nil {
|
||||
str += fmt.Sprintf(" model=%q", model.path)
|
||||
}
|
||||
|
||||
return str + ">"
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Return true if model is multilingual (language and translation options are supported)
|
||||
func (model *model) IsMultilingual() bool {
|
||||
return model.ctx.Whisper_is_multilingual() != 0
|
||||
func (model *ModelContext) IsMultilingual() bool {
|
||||
ctx, err := model.ca.context()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return ctx.Whisper_is_multilingual() != 0
|
||||
}
|
||||
|
||||
// Return all recognized languages. Initially it is set to auto-detect
|
||||
func (model *model) Languages() []string {
|
||||
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||
for i := 0; i < whisper.Whisper_lang_max_id(); i++ {
|
||||
str := whisper.Whisper_lang_str(i)
|
||||
if model.ctx.Whisper_lang_id(str) >= 0 {
|
||||
func (model *ModelContext) Languages() []string {
|
||||
ctx, err := model.ca.context()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]string, 0, low.Whisper_lang_max_id())
|
||||
for i := 0; i < low.Whisper_lang_max_id(); i++ {
|
||||
str := low.Whisper_lang_str(i)
|
||||
if ctx.Whisper_lang_id(str) >= 0 {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (model *model) NewContext() (Context, error) {
|
||||
if model.ctx == nil {
|
||||
return nil, ErrInternalAppError
|
||||
// NewContext creates a new speech-to-text context.
|
||||
// Each context is backed by an isolated whisper_state for safe concurrent processing.
|
||||
func (model *ModelContext) NewContext() (Context, error) {
|
||||
// Create new context with default params
|
||||
params, err := NewParameters(model, SAMPLING_GREEDY, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create new context
|
||||
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
params.SetTranslate(false)
|
||||
params.SetPrintSpecial(false)
|
||||
params.SetPrintProgress(false)
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
params.SetNoContext(true)
|
||||
|
||||
// Return new context
|
||||
return newContext(model, params)
|
||||
// Return new context (stateless for backward compatibility with timings)
|
||||
return NewStatelessContext(
|
||||
model,
|
||||
params,
|
||||
)
|
||||
}
|
||||
|
||||
// PrintTimings prints the model performance timings to stdout.
|
||||
func (model *ModelContext) PrintTimings() {
|
||||
ctx, err := model.ca.context()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Whisper_print_timings()
|
||||
}
|
||||
|
||||
// ResetTimings resets the model performance timing counters.
|
||||
func (model *ModelContext) ResetTimings() {
|
||||
ctx, err := model.ca.context()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Whisper_reset_timings()
|
||||
}
|
||||
|
||||
// GetTimings returns a compact snapshot of model-level processing timings.
|
||||
//
|
||||
// Behavior notes:
|
||||
// - Stateless contexts (created via ModelContext.NewContext or NewStatelessContext)
|
||||
// update model-level timings during Process. After a stateless Process call,
|
||||
// the returned timings are expected to be non-zero (ok == true).
|
||||
// - Stateful contexts (created via NewStatefulContext) use a per-state backend
|
||||
// and do not affect model-level timings. After a stateful Process call,
|
||||
// the returned timings are expected to be zero values (fields equal 0) or
|
||||
// the call may return ok == false depending on the underlying implementation.
|
||||
//
|
||||
// Use ResetTimings before measurement to clear previous values.
|
||||
func (model *ModelContext) GetTimings() (Timings, bool) {
|
||||
ctx, err := model.ca.context()
|
||||
if err != nil {
|
||||
return Timings{}, false
|
||||
}
|
||||
if t, ok := ctx.Whisper_get_timings_go(); ok {
|
||||
return Timings{
|
||||
SampleMS: t.SampleMS,
|
||||
EncodeMS: t.EncodeMS,
|
||||
DecodeMS: t.DecodeMS,
|
||||
BatchdMS: t.BatchdMS,
|
||||
PromptMS: t.PromptMS,
|
||||
}, true
|
||||
}
|
||||
return Timings{}, false
|
||||
}
|
||||
|
||||
func (model *ModelContext) tokenIdentifier() *tokenIdentifier {
|
||||
return model.tokId
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
low "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
type ModelContextParams struct {
|
||||
p low.ContextParams
|
||||
}
|
||||
|
||||
func NewModelContextParams() ModelContextParams {
|
||||
return ModelContextParams{
|
||||
p: low.Whisper_context_default_params(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ModelContextParams) SetUseGPU(v bool) {
|
||||
p.p.SetUseGPU(v)
|
||||
}
|
||||
|
||||
func (p *ModelContextParams) SetGPUDevice(n int) {
|
||||
p.p.SetGPUDevice(n)
|
||||
}
|
||||
|
||||
func (p *ModelContextParams) toLow() low.ContextParams {
|
||||
return p.p
|
||||
}
|
||||
|
|
@ -13,7 +13,7 @@ func TestNew(t *testing.T) {
|
|||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
})
|
||||
|
||||
|
|
@ -42,20 +42,34 @@ func TestNewContext(t *testing.T) {
|
|||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.NoError(err)
|
||||
assert.NotNil(context)
|
||||
}
|
||||
|
||||
func TestNewContext_ClosedModel(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
assert.NoError(model.Close())
|
||||
|
||||
context, err := model.NewContext()
|
||||
assert.ErrorIs(err, whisper.ErrInternalAppError)
|
||||
assert.ErrorIs(err, whisper.ErrModelClosed)
|
||||
assert.Nil(context)
|
||||
}
|
||||
|
||||
func TestIsMultilingual(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
isMultilingual := model.IsMultilingual()
|
||||
|
||||
|
|
@ -71,7 +85,7 @@ func TestLanguages(t *testing.T) {
|
|||
model, err := whisper.New(ModelPath)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(model)
|
||||
defer model.Close()
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
expectedLanguages := []string{
|
||||
"en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,121 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
// Parameters is a high-level wrapper that implements the Parameters interface
|
||||
// and delegates to the underlying low-level whisper.Params.
|
||||
type Parameters struct {
|
||||
p *whisper.Params
|
||||
}
|
||||
|
||||
func defaultParamsConfigure(params *Parameters) {
|
||||
params.SetTranslate(false)
|
||||
params.SetPrintSpecial(false)
|
||||
params.SetPrintProgress(false)
|
||||
params.SetPrintRealtime(false)
|
||||
params.SetPrintTimestamps(false)
|
||||
// Default behavior backward compatibility
|
||||
params.SetThreads(uint(runtime.NumCPU()))
|
||||
params.SetNoContext(true)
|
||||
}
|
||||
|
||||
func NewParameters(
|
||||
model *ModelContext,
|
||||
sampling SamplingStrategy,
|
||||
configure ParamsConfigure,
|
||||
) (*Parameters, error) {
|
||||
ctx, err := model.ca.context()
|
||||
if err != nil {
|
||||
return nil, ErrModelClosed
|
||||
}
|
||||
|
||||
p := ctx.Whisper_full_default_params(whisper.SamplingStrategy(sampling))
|
||||
safeParams := &Parameters{
|
||||
p: &p,
|
||||
}
|
||||
|
||||
defaultParamsConfigure(safeParams)
|
||||
|
||||
if configure != nil {
|
||||
configure(safeParams)
|
||||
}
|
||||
|
||||
return safeParams, nil
|
||||
}
|
||||
|
||||
func (w *Parameters) SetTranslate(v bool) { w.p.SetTranslate(v) }
|
||||
func (w *Parameters) SetSplitOnWord(v bool) { w.p.SetSplitOnWord(v) }
|
||||
func (w *Parameters) SetThreads(v uint) { w.p.SetThreads(int(v)) }
|
||||
func (w *Parameters) SetOffset(d time.Duration) { w.p.SetOffset(int(d.Milliseconds())) }
|
||||
func (w *Parameters) SetDuration(d time.Duration) { w.p.SetDuration(int(d.Milliseconds())) }
|
||||
func (w *Parameters) SetTokenThreshold(t float32) { w.p.SetTokenThreshold(t) }
|
||||
func (w *Parameters) SetTokenSumThreshold(t float32) { w.p.SetTokenSumThreshold(t) }
|
||||
func (w *Parameters) SetMaxSegmentLength(n uint) { w.p.SetMaxSegmentLength(int(n)) }
|
||||
func (w *Parameters) SetTokenTimestamps(b bool) { w.p.SetTokenTimestamps(b) }
|
||||
func (w *Parameters) SetMaxTokensPerSegment(n uint) { w.p.SetMaxTokensPerSegment(int(n)) }
|
||||
func (w *Parameters) SetAudioCtx(n uint) { w.p.SetAudioCtx(int(n)) }
|
||||
func (w *Parameters) SetMaxContext(n int) { w.p.SetMaxContext(n) }
|
||||
func (w *Parameters) SetBeamSize(n int) { w.p.SetBeamSize(n) }
|
||||
func (w *Parameters) SetEntropyThold(t float32) { w.p.SetEntropyThold(t) }
|
||||
func (w *Parameters) SetInitialPrompt(prompt string) { w.p.SetInitialPrompt(prompt) }
|
||||
func (w *Parameters) SetCarryInitialPrompt(v bool) { w.p.SetCarryInitialPrompt(v) }
|
||||
func (w *Parameters) SetTemperature(t float32) { w.p.SetTemperature(t) }
|
||||
func (w *Parameters) SetTemperatureFallback(t float32) { w.p.SetTemperatureFallback(t) }
|
||||
func (w *Parameters) SetNoContext(v bool) { w.p.SetNoContext(v) }
|
||||
func (w *Parameters) SetPrintSpecial(v bool) { w.p.SetPrintSpecial(v) }
|
||||
func (w *Parameters) SetPrintProgress(v bool) { w.p.SetPrintProgress(v) }
|
||||
func (w *Parameters) SetPrintRealtime(v bool) { w.p.SetPrintRealtime(v) }
|
||||
func (w *Parameters) SetPrintTimestamps(v bool) { w.p.SetPrintTimestamps(v) }
|
||||
func (w *Parameters) SetDebugMode(v bool) { w.p.SetDebugMode(v) }
|
||||
|
||||
// Diarization (tinydiarize)
|
||||
func (w *Parameters) SetDiarize(v bool) { w.p.SetDiarize(v) }
|
||||
|
||||
// Voice Activity Detection (VAD)
|
||||
func (w *Parameters) SetVAD(v bool) { w.p.SetVAD(v) }
|
||||
func (w *Parameters) SetVADModelPath(p string) { w.p.SetVADModelPath(p) }
|
||||
func (w *Parameters) SetVADThreshold(t float32) { w.p.SetVADThreshold(t) }
|
||||
func (w *Parameters) SetVADMinSpeechMs(ms int) { w.p.SetVADMinSpeechMs(ms) }
|
||||
func (w *Parameters) SetVADMinSilenceMs(ms int) { w.p.SetVADMinSilenceMs(ms) }
|
||||
func (w *Parameters) SetVADMaxSpeechSec(s float32) { w.p.SetVADMaxSpeechSec(s) }
|
||||
func (w *Parameters) SetVADSpeechPadMs(ms int) { w.p.SetVADSpeechPadMs(ms) }
|
||||
func (w *Parameters) SetVADSamplesOverlap(sec float32) { w.p.SetVADSamplesOverlap(sec) }
|
||||
|
||||
func (w *Parameters) SetLanguage(lang string) error {
|
||||
if lang == "auto" {
|
||||
return w.p.SetLanguage(-1)
|
||||
}
|
||||
id := whisper.Whisper_lang_id_str(lang)
|
||||
if id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
}
|
||||
return w.p.SetLanguage(id)
|
||||
}
|
||||
|
||||
func (w *Parameters) SetSingleSegment(v bool) {
|
||||
w.p.SetSingleSegment(v)
|
||||
}
|
||||
|
||||
// Getter methods for Parameters interface
|
||||
func (w *Parameters) Language() string {
|
||||
id := w.p.Language()
|
||||
if id == -1 {
|
||||
return "auto"
|
||||
}
|
||||
|
||||
return whisper.Whisper_lang_str(id)
|
||||
}
|
||||
|
||||
func (w *Parameters) Threads() int {
|
||||
return w.p.Threads()
|
||||
}
|
||||
|
||||
func (w *Parameters) unsafeParams() (*whisper.Params, error) {
|
||||
return w.p, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,438 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
type StatefulContext struct {
|
||||
n int
|
||||
model *ModelContext
|
||||
st *whisperState
|
||||
params *Parameters
|
||||
}
|
||||
|
||||
// NewStatefulContext creates a new stateful context
|
||||
func NewStatefulContext(model *ModelContext, params *Parameters) (*StatefulContext, error) {
|
||||
if model == nil {
|
||||
return nil, errModelRequired
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return nil, errParametersRequired
|
||||
}
|
||||
|
||||
c := new(StatefulContext)
|
||||
c.model = model
|
||||
c.params = params
|
||||
|
||||
// allocate isolated state per context
|
||||
ctx, err := model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
st := ctx.Whisper_init_state()
|
||||
if st == nil {
|
||||
return nil, errUnableToCreateState
|
||||
}
|
||||
|
||||
c.st = newWhisperState(st)
|
||||
|
||||
// Return success
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DetectedLanguage returns the detected language for the current context data
|
||||
func (context *StatefulContext) DetectedLanguage() string {
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
st, err := context.st.unsafeState()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return whisper.Whisper_lang_str(
|
||||
ctx.Whisper_full_lang_id_from_state(
|
||||
st,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// Close frees the whisper state and marks the context as closed.
|
||||
func (context *StatefulContext) Close() error {
|
||||
return context.st.close()
|
||||
}
|
||||
|
||||
// Params returns a high-level parameters wrapper
|
||||
func (context *StatefulContext) Params() *Parameters {
|
||||
return context.params
|
||||
}
|
||||
|
||||
// ResetTimings resets the model performance timing counters.
|
||||
// Deprecated: Use Model.ResetTimings() instead - these are model-level performance metrics.
|
||||
func (context *StatefulContext) ResetTimings() {
|
||||
context.model.ResetTimings()
|
||||
}
|
||||
|
||||
// PrintTimings prints the model performance timings to stdout.
|
||||
// Deprecated: Use Model.PrintTimings() instead - these are model-level performance metrics.
|
||||
func (context *StatefulContext) PrintTimings() {
|
||||
context.model.PrintTimings()
|
||||
}
|
||||
|
||||
// SystemInfo returns the system information
|
||||
func (context *StatefulContext) SystemInfo() string {
|
||||
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
|
||||
context.params.Threads(),
|
||||
runtime.NumCPU(),
|
||||
whisper.Whisper_print_system_info(),
|
||||
)
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages for this context's state.
|
||||
func (context *StatefulContext) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
st, err := context.st.unsafeState()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
langProbs, err := ctx.Whisper_lang_auto_detect_with_state(st, offset_ms, n_threads)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return langProbs, nil
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *StatefulContext) Process(
|
||||
data []float32,
|
||||
callEncoderBegin EncoderBeginCallback,
|
||||
callNewSegment SegmentCallback,
|
||||
callProgress ProgressCallback,
|
||||
) error {
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if callNewSegment != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
lowLevelParams, err := context.params.unsafeParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
st, err := context.st.unsafeState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ctx.Whisper_full_with_state(st, *lowLevelParams, data, callEncoderBegin,
|
||||
func(new int) {
|
||||
if callNewSegment != nil {
|
||||
num_segments := ctx.Whisper_full_n_segments_from_state(st)
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
callNewSegment(toSegmentFromState(ctx, st, i))
|
||||
}
|
||||
}
|
||||
}, func(progress int) {
|
||||
if callProgress != nil {
|
||||
callProgress(progress)
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextSegment returns the next segment from the context buffer
|
||||
func (context *StatefulContext) NextSegment() (Segment, error) {
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return Segment{}, err
|
||||
}
|
||||
|
||||
st, err := context.st.unsafeState()
|
||||
if err != nil {
|
||||
return Segment{}, err
|
||||
}
|
||||
|
||||
if context.n >= ctx.Whisper_full_n_segments_from_state(st) {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
result := toSegmentFromState(ctx, st, context.n)
|
||||
context.n++
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (context *StatefulContext) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Token helpers
|
||||
// Deprecated: Use Model.IsText() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsText(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsText(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsBEG() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsBEG(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsBEG(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsSOT() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsSOT(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsSOT(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsEOT() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsEOT(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsEOT(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsPREV() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsPREV(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsPREV(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsSOLM() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsSOLM(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsSOLM(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsNOT() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsNOT(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsNOT(t)
|
||||
return result
|
||||
}
|
||||
|
||||
func (context *StatefulContext) SetLanguage(lang string) error {
|
||||
if context.model.ctxAccessor().isClosed() {
|
||||
// TODO: remove this logic after deprecating the ErrInternalAppError
|
||||
return ErrModelClosed
|
||||
}
|
||||
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
|
||||
return context.params.SetLanguage(lang)
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsLANG() instead - token checking is model-specific.
|
||||
func (context *StatefulContext) IsLANG(t Token, lang string) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsLANG(t, lang)
|
||||
return result
|
||||
}
|
||||
|
||||
// State-backed helper functions
|
||||
func toSegmentFromState(ctx *whisper.Context, st *whisper.State, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text_from_state(st, n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0_from_state(st, n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1_from_state(st, n)) * time.Millisecond * 10,
|
||||
Tokens: toTokensFromState(ctx, st, n),
|
||||
SpeakerTurnNext: ctx.Whisper_full_get_segment_speaker_turn_next_from_state(st, n),
|
||||
}
|
||||
}
|
||||
|
||||
func toTokensFromState(ctx *whisper.Context, st *whisper.State, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens_from_state(st, n))
|
||||
|
||||
for i := 0; i < len(result); i++ {
|
||||
data := ctx.Whisper_full_get_token_data_from_state(st, n, i)
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id_from_state(st, n, i)),
|
||||
Text: ctx.Whisper_full_get_token_text_from_state(st, n, i),
|
||||
P: ctx.Whisper_full_get_token_p_from_state(st, n, i),
|
||||
Start: time.Duration(data.T0()) * time.Millisecond * 10,
|
||||
End: time.Duration(data.T1()) * time.Millisecond * 10,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().Language() instead
|
||||
func (context *StatefulContext) Language() string {
|
||||
return context.params.Language()
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetAudioCtx() instead
|
||||
func (context *StatefulContext) SetAudioCtx(n uint) {
|
||||
context.params.SetAudioCtx(n)
|
||||
}
|
||||
|
||||
// SetBeamSize implements Context.
|
||||
// Deprecated: Use Params().SetBeamSize() instead
|
||||
func (context *StatefulContext) SetBeamSize(v int) {
|
||||
context.params.SetBeamSize(v)
|
||||
}
|
||||
|
||||
// SetDuration implements Context.
|
||||
// Deprecated: Use Params().SetDuration() instead
|
||||
func (context *StatefulContext) SetDuration(v time.Duration) {
|
||||
context.params.SetDuration(v)
|
||||
}
|
||||
|
||||
// SetEntropyThold implements Context.
|
||||
// Deprecated: Use Params().SetEntropyThold() instead
|
||||
func (context *StatefulContext) SetEntropyThold(v float32) {
|
||||
context.params.SetEntropyThold(v)
|
||||
}
|
||||
|
||||
// SetInitialPrompt implements Context.
|
||||
// Deprecated: Use Params().SetInitialPrompt() instead
|
||||
func (context *StatefulContext) SetInitialPrompt(v string) {
|
||||
context.params.SetInitialPrompt(v)
|
||||
}
|
||||
|
||||
// SetMaxContext implements Context.
|
||||
// Deprecated: Use Params().SetMaxContext() instead
|
||||
func (context *StatefulContext) SetMaxContext(v int) {
|
||||
context.params.SetMaxContext(v)
|
||||
}
|
||||
|
||||
// SetMaxSegmentLength implements Context.
|
||||
// Deprecated: Use Params().SetMaxSegmentLength() instead
|
||||
func (context *StatefulContext) SetMaxSegmentLength(v uint) {
|
||||
context.params.SetMaxSegmentLength(v)
|
||||
}
|
||||
|
||||
// SetMaxTokensPerSegment implements Context.
|
||||
// Deprecated: Use Params().SetMaxTokensPerSegment() instead
|
||||
func (context *StatefulContext) SetMaxTokensPerSegment(v uint) {
|
||||
context.params.SetMaxTokensPerSegment(v)
|
||||
}
|
||||
|
||||
// SetOffset implements Context.
|
||||
// Deprecated: Use Params().SetOffset() instead
|
||||
func (context *StatefulContext) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(v)
|
||||
}
|
||||
|
||||
// SetSplitOnWord implements Context.
|
||||
// Deprecated: Use Params().SetSplitOnWord() instead
|
||||
func (context *StatefulContext) SetSplitOnWord(v bool) {
|
||||
context.params.SetSplitOnWord(v)
|
||||
}
|
||||
|
||||
// SetTemperature implements Context.
|
||||
// Deprecated: Use Params().SetTemperature() instead
|
||||
func (context *StatefulContext) SetTemperature(v float32) {
|
||||
context.params.SetTemperature(v)
|
||||
}
|
||||
|
||||
// SetTemperatureFallback implements Context.
|
||||
// Deprecated: Use Params().SetTemperatureFallback() instead
|
||||
func (context *StatefulContext) SetTemperatureFallback(v float32) {
|
||||
context.params.SetTemperatureFallback(v)
|
||||
}
|
||||
|
||||
// SetThreads implements Context.
|
||||
// Deprecated: Use Params().SetThreads() instead
|
||||
func (context *StatefulContext) SetThreads(v uint) {
|
||||
context.params.SetThreads(v)
|
||||
}
|
||||
|
||||
// SetTokenSumThreshold implements Context.
|
||||
// Deprecated: Use Params().SetTokenSumThreshold() instead
|
||||
func (context *StatefulContext) SetTokenSumThreshold(v float32) {
|
||||
context.params.SetTokenSumThreshold(v)
|
||||
}
|
||||
|
||||
// SetTokenThreshold implements Context.
|
||||
// Deprecated: Use Params().SetTokenThreshold() instead
|
||||
func (context *StatefulContext) SetTokenThreshold(v float32) {
|
||||
context.params.SetTokenThreshold(v)
|
||||
}
|
||||
|
||||
// SetTokenTimestamps implements Context.
|
||||
// Deprecated: Use Params().SetTokenTimestamps() instead
|
||||
func (context *StatefulContext) SetTokenTimestamps(v bool) {
|
||||
context.params.SetTokenTimestamps(v)
|
||||
}
|
||||
|
||||
// SetTranslate implements Context.
|
||||
// Deprecated: Use Params().SetTranslate() instead
|
||||
func (context *StatefulContext) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// VAD methods - implement Context interface
|
||||
// Deprecated: Use Params().SetVAD() instead
|
||||
func (context *StatefulContext) SetVAD(v bool) {
|
||||
context.params.SetVAD(v)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADModelPath() instead
|
||||
func (context *StatefulContext) SetVADModelPath(path string) {
|
||||
context.params.SetVADModelPath(path)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADThreshold() instead
|
||||
func (context *StatefulContext) SetVADThreshold(t float32) {
|
||||
context.params.SetVADThreshold(t)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADMinSpeechMs() instead
|
||||
func (context *StatefulContext) SetVADMinSpeechMs(ms int) {
|
||||
context.params.SetVADMinSpeechMs(ms)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADMinSilenceMs() instead
|
||||
func (context *StatefulContext) SetVADMinSilenceMs(ms int) {
|
||||
context.params.SetVADMinSilenceMs(ms)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADMaxSpeechSec() instead
|
||||
func (context *StatefulContext) SetVADMaxSpeechSec(s float32) {
|
||||
context.params.SetVADMaxSpeechSec(s)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADSpeechPadMs() instead
|
||||
func (context *StatefulContext) SetVADSpeechPadMs(ms int) {
|
||||
context.params.SetVADSpeechPadMs(ms)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADSamplesOverlap() instead
|
||||
func (context *StatefulContext) SetVADSamplesOverlap(sec float32) {
|
||||
context.params.SetVADSamplesOverlap(sec)
|
||||
}
|
||||
|
||||
// Make stateful context compatible with the old deprecated interface for
|
||||
// the simple migration into multi-threaded processing.
|
||||
var _ Context = (*StatefulContext)(nil)
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Stateful-specific: parallel processing supported
|
||||
func TestContext_Parallel_DifferentInputs_Stateful(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
assert.Greater(len(data), 10)
|
||||
|
||||
// Create half-sample (second half)
|
||||
half := make([]float32, len(data)/2)
|
||||
copy(half, data[len(data)/2:])
|
||||
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = model.Close() }()
|
||||
|
||||
params1 := helperNewParams(t, model, nil)
|
||||
params2 := helperNewParams(t, model, nil)
|
||||
|
||||
ctx1, err := whisper.NewStatefulContext(model, params1)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = ctx1.Close() }()
|
||||
ctx2, err := whisper.NewStatefulContext(model, params2)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = ctx2.Close() }()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var first1, first2 string
|
||||
var e1, e2 error
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
e1 = ctx1.Process(data, nil, nil, nil)
|
||||
if e1 == nil {
|
||||
seg, err := ctx1.NextSegment()
|
||||
if err == nil {
|
||||
first1 = seg.Text
|
||||
} else {
|
||||
e1 = err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
e2 = ctx2.Process(half, nil, nil, nil)
|
||||
if e2 == nil {
|
||||
seg, err := ctx2.NextSegment()
|
||||
if err == nil {
|
||||
first2 = seg.Text
|
||||
} else {
|
||||
e2 = err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
assert.NoError(e1)
|
||||
assert.NoError(e2)
|
||||
assert.NotEmpty(first1)
|
||||
assert.NotEmpty(first2)
|
||||
assert.NotEqual(first1, first2, "first segments should differ for different inputs")
|
||||
}
|
||||
|
|
@ -0,0 +1,418 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Bindings
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
)
|
||||
|
||||
type StatelessContext struct {
|
||||
n int
|
||||
model *ModelContext
|
||||
params *Parameters
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewStatelessContext creates a new stateless context backed by the model's context
|
||||
func NewStatelessContext(model *ModelContext, params *Parameters) (*StatelessContext, error) {
|
||||
if model == nil {
|
||||
return nil, errModelRequired
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return nil, errParametersRequired
|
||||
}
|
||||
|
||||
// Ensure model context is available
|
||||
if _, err := model.ctxAccessor().context(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := new(StatelessContext)
|
||||
c.model = model
|
||||
c.params = params
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DetectedLanguage returns the detected language for the current context data
|
||||
func (context *StatelessContext) DetectedLanguage() string {
|
||||
if context.closed {
|
||||
return ""
|
||||
}
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return whisper.Whisper_lang_str(ctx.Whisper_full_lang_id())
|
||||
}
|
||||
|
||||
// Close marks the context as closed.
|
||||
func (context *StatelessContext) Close() error {
|
||||
context.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Params returns a high-level parameters wrapper
|
||||
func (context *StatelessContext) Params() *Parameters {
|
||||
return context.params
|
||||
}
|
||||
|
||||
// ResetTimings resets the model performance timing counters.
|
||||
// Deprecated: Use Model.ResetTimings() instead - these are model-level performance metrics.
|
||||
func (context *StatelessContext) ResetTimings() {
|
||||
context.model.ResetTimings()
|
||||
}
|
||||
|
||||
// PrintTimings prints the model performance timings to stdout.
|
||||
// Deprecated: Use Model.PrintTimings() instead - these are model-level performance metrics.
|
||||
func (context *StatelessContext) PrintTimings() {
|
||||
context.model.PrintTimings()
|
||||
}
|
||||
|
||||
// SystemInfo returns the system information
|
||||
func (context *StatelessContext) SystemInfo() string {
|
||||
return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
|
||||
context.params.Threads(),
|
||||
runtime.NumCPU(),
|
||||
whisper.Whisper_print_system_info(),
|
||||
)
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// Returns the probabilities of all languages for this context.
|
||||
func (context *StatelessContext) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
|
||||
if context.closed {
|
||||
return nil, ErrModelClosed
|
||||
}
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
langProbs, err := ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return langProbs, nil
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *StatelessContext) Process(
|
||||
data []float32,
|
||||
callEncoderBegin EncoderBeginCallback,
|
||||
callNewSegment SegmentCallback,
|
||||
callProgress ProgressCallback,
|
||||
) error {
|
||||
if context.closed {
|
||||
return ErrModelClosed
|
||||
}
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Concurrency guard: prevent concurrent stateless processing on shared model ctx
|
||||
k := modelKey(context.model)
|
||||
if !gate().Acquire(k) {
|
||||
return ErrStatelessBusy
|
||||
}
|
||||
defer gate().Release(k)
|
||||
|
||||
// If the callback is defined then we force on single_segment mode
|
||||
if callNewSegment != nil {
|
||||
context.params.SetSingleSegment(true)
|
||||
}
|
||||
|
||||
lowLevelParams, err := context.params.unsafeParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ctx.Whisper_full(*lowLevelParams, data, callEncoderBegin,
|
||||
func(new int) {
|
||||
if callNewSegment != nil {
|
||||
num_segments := ctx.Whisper_full_n_segments()
|
||||
s0 := num_segments - new
|
||||
for i := s0; i < num_segments; i++ {
|
||||
callNewSegment(toSegmentFromContext(ctx, i))
|
||||
}
|
||||
}
|
||||
}, func(progress int) {
|
||||
if callProgress != nil {
|
||||
callProgress(progress)
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextSegment returns the next segment from the context buffer
|
||||
func (context *StatelessContext) NextSegment() (Segment, error) {
|
||||
if context.closed {
|
||||
return Segment{}, ErrModelClosed
|
||||
}
|
||||
ctx, err := context.model.ctxAccessor().context()
|
||||
if err != nil {
|
||||
return Segment{}, err
|
||||
}
|
||||
|
||||
if context.n >= ctx.Whisper_full_n_segments() {
|
||||
return Segment{}, io.EOF
|
||||
}
|
||||
|
||||
result := toSegmentFromContext(ctx, context.n)
|
||||
context.n++
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (context *StatelessContext) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Token helpers
|
||||
// Deprecated: Use Model.IsText() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsText(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsText(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsBEG() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsBEG(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsBEG(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsSOT() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsSOT(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsSOT(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsEOT() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsEOT(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsEOT(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsPREV() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsPREV(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsPREV(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsSOLM() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsSOLM(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsSOLM(t)
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsNOT() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsNOT(t Token) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsNOT(t)
|
||||
return result
|
||||
}
|
||||
|
||||
func (context *StatelessContext) SetLanguage(lang string) error {
|
||||
if context.closed || context.model.ctxAccessor().isClosed() {
|
||||
return ErrModelClosed
|
||||
}
|
||||
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
|
||||
return context.params.SetLanguage(lang)
|
||||
}
|
||||
|
||||
// Deprecated: Use Model.IsLANG() instead - token checking is model-specific.
|
||||
func (context *StatelessContext) IsLANG(t Token, lang string) bool {
|
||||
result, _ := context.model.tokenIdentifier().IsLANG(t, lang)
|
||||
return result
|
||||
}
|
||||
|
||||
// Context-backed helper functions
|
||||
func toSegmentFromContext(ctx *whisper.Context, n int) Segment {
|
||||
return Segment{
|
||||
Num: n,
|
||||
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
|
||||
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
|
||||
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
|
||||
Tokens: toTokensFromContext(ctx, n),
|
||||
SpeakerTurnNext: false, // speaker turn available only with state-backed accessors
|
||||
}
|
||||
}
|
||||
|
||||
func toTokensFromContext(ctx *whisper.Context, n int) []Token {
|
||||
result := make([]Token, ctx.Whisper_full_n_tokens(n))
|
||||
|
||||
for i := 0; i < len(result); i++ {
|
||||
data := ctx.Whisper_full_get_token_data(n, i)
|
||||
result[i] = Token{
|
||||
Id: int(ctx.Whisper_full_get_token_id(n, i)),
|
||||
Text: ctx.Whisper_full_get_token_text(n, i),
|
||||
P: ctx.Whisper_full_get_token_p(n, i),
|
||||
Start: time.Duration(data.T0()) * time.Millisecond * 10,
|
||||
End: time.Duration(data.T1()) * time.Millisecond * 10,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().Language() instead
|
||||
func (context *StatelessContext) Language() string {
|
||||
return context.params.Language()
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetAudioCtx() instead
|
||||
func (context *StatelessContext) SetAudioCtx(n uint) {
|
||||
context.params.SetAudioCtx(n)
|
||||
}
|
||||
|
||||
// SetBeamSize implements Context.
|
||||
// Deprecated: Use Params().SetBeamSize() instead
|
||||
func (context *StatelessContext) SetBeamSize(v int) {
|
||||
context.params.SetBeamSize(v)
|
||||
}
|
||||
|
||||
// SetDuration implements Context.
|
||||
// Deprecated: Use Params().SetDuration() instead
|
||||
func (context *StatelessContext) SetDuration(v time.Duration) {
|
||||
context.params.SetDuration(v)
|
||||
}
|
||||
|
||||
// SetEntropyThold implements Context.
|
||||
// Deprecated: Use Params().SetEntropyThold() instead
|
||||
func (context *StatelessContext) SetEntropyThold(v float32) {
|
||||
context.params.SetEntropyThold(v)
|
||||
}
|
||||
|
||||
// SetInitialPrompt implements Context.
|
||||
// Deprecated: Use Params().SetInitialPrompt() instead
|
||||
func (context *StatelessContext) SetInitialPrompt(v string) {
|
||||
context.params.SetInitialPrompt(v)
|
||||
}
|
||||
|
||||
// SetMaxContext implements Context.
|
||||
// Deprecated: Use Params().SetMaxContext() instead
|
||||
func (context *StatelessContext) SetMaxContext(v int) {
|
||||
context.params.SetMaxContext(v)
|
||||
}
|
||||
|
||||
// SetMaxSegmentLength implements Context.
|
||||
// Deprecated: Use Params().SetMaxSegmentLength() instead
|
||||
func (context *StatelessContext) SetMaxSegmentLength(v uint) {
|
||||
context.params.SetMaxSegmentLength(v)
|
||||
}
|
||||
|
||||
// SetMaxTokensPerSegment implements Context.
|
||||
// Deprecated: Use Params().SetMaxTokensPerSegment() instead
|
||||
func (context *StatelessContext) SetMaxTokensPerSegment(v uint) {
|
||||
context.params.SetMaxTokensPerSegment(v)
|
||||
}
|
||||
|
||||
// SetOffset implements Context.
|
||||
// Deprecated: Use Params().SetOffset() instead
|
||||
func (context *StatelessContext) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(v)
|
||||
}
|
||||
|
||||
// SetSplitOnWord implements Context.
|
||||
// Deprecated: Use Params().SetSplitOnWord() instead
|
||||
func (context *StatelessContext) SetSplitOnWord(v bool) {
|
||||
context.params.SetSplitOnWord(v)
|
||||
}
|
||||
|
||||
// SetTemperature implements Context.
|
||||
// Deprecated: Use Params().SetTemperature() instead
|
||||
func (context *StatelessContext) SetTemperature(v float32) {
|
||||
context.params.SetTemperature(v)
|
||||
}
|
||||
|
||||
// SetTemperatureFallback implements Context.
|
||||
// Deprecated: Use Params().SetTemperatureFallback() instead
|
||||
func (context *StatelessContext) SetTemperatureFallback(v float32) {
|
||||
context.params.SetTemperatureFallback(v)
|
||||
}
|
||||
|
||||
// SetThreads implements Context.
|
||||
// Deprecated: Use Params().SetThreads() instead
|
||||
func (context *StatelessContext) SetThreads(v uint) {
|
||||
context.params.SetThreads(v)
|
||||
}
|
||||
|
||||
// SetTokenSumThreshold implements Context.
|
||||
// Deprecated: Use Params().SetTokenSumThreshold() instead
|
||||
func (context *StatelessContext) SetTokenSumThreshold(v float32) {
|
||||
context.params.SetTokenSumThreshold(v)
|
||||
}
|
||||
|
||||
// SetTokenThreshold implements Context.
|
||||
// Deprecated: Use Params().SetTokenThreshold() instead
|
||||
func (context *StatelessContext) SetTokenThreshold(v float32) {
|
||||
context.params.SetTokenThreshold(v)
|
||||
}
|
||||
|
||||
// SetTokenTimestamps implements Context.
|
||||
// Deprecated: Use Params().SetTokenTimestamps() instead
|
||||
func (context *StatelessContext) SetTokenTimestamps(v bool) {
|
||||
context.params.SetTokenTimestamps(v)
|
||||
}
|
||||
|
||||
// SetTranslate implements Context.
|
||||
// Deprecated: Use Params().SetTranslate() instead
|
||||
func (context *StatelessContext) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// VAD methods - implement Context interface
|
||||
// Deprecated: Use Params().SetVAD() instead
|
||||
func (context *StatelessContext) SetVAD(v bool) {
|
||||
context.params.SetVAD(v)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADModelPath() instead
|
||||
func (context *StatelessContext) SetVADModelPath(path string) {
|
||||
context.params.SetVADModelPath(path)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADThreshold() instead
|
||||
func (context *StatelessContext) SetVADThreshold(t float32) {
|
||||
context.params.SetVADThreshold(t)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADMinSpeechMs() instead
|
||||
func (context *StatelessContext) SetVADMinSpeechMs(ms int) {
|
||||
context.params.SetVADMinSpeechMs(ms)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADMinSilenceMs() instead
|
||||
func (context *StatelessContext) SetVADMinSilenceMs(ms int) {
|
||||
context.params.SetVADMinSilenceMs(ms)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADMaxSpeechSec() instead
|
||||
func (context *StatelessContext) SetVADMaxSpeechSec(s float32) {
|
||||
context.params.SetVADMaxSpeechSec(s)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADSpeechPadMs() instead
|
||||
func (context *StatelessContext) SetVADSpeechPadMs(ms int) {
|
||||
context.params.SetVADSpeechPadMs(ms)
|
||||
}
|
||||
|
||||
// Deprecated: Use Params().SetVADSamplesOverlap() instead
|
||||
func (context *StatelessContext) SetVADSamplesOverlap(sec float32) {
|
||||
context.params.SetVADSamplesOverlap(sec)
|
||||
}
|
||||
|
||||
var _ Context = (*StatelessContext)(nil)
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Ensure stateless contexts cannot process in parallel without isolation
|
||||
func TestStatelessContext_NotParallelSafe(t *testing.T) {
|
||||
data := helperLoadSample(t, SamplePath)
|
||||
|
||||
model, closeModel := helperNewModelContext(t)
|
||||
defer closeModel()
|
||||
|
||||
params := helperNewParams(t, model, nil)
|
||||
|
||||
// Create two stateless contexts sharing the same underlying model context
|
||||
ctx1, err := whisper.NewStatelessContext(model, params)
|
||||
assert.NoError(t, err)
|
||||
defer func() { _ = ctx1.Close() }()
|
||||
|
||||
ctx2, err := whisper.NewStatelessContext(model, params)
|
||||
assert.NoError(t, err)
|
||||
defer func() { _ = ctx2.Close() }()
|
||||
|
||||
// Run both in parallel - expect a panic or error from underlying whisper_full
|
||||
// We capture panics to assert the behavior.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var err1, err2 error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err1 = ctx1.Process(data, nil, nil, nil)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err2 = ctx2.Process(data, nil, nil, nil)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// At least one should return ErrStatelessBusy
|
||||
if err1 != whisper.ErrStatelessBusy && err2 != whisper.ErrStatelessBusy {
|
||||
t.Fatalf("expected ErrStatelessBusy when processing in parallel with StatelessContext, got err1=%v err2=%v", err1, err2)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func helperLoadSample(tb testing.TB, path string) []float32 {
|
||||
tb.Helper()
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
tb.Fatalf("open sample: %v", err)
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
dec := wav.NewDecoder(fh)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
if err != nil {
|
||||
tb.Fatalf("decode wav: %v", err)
|
||||
}
|
||||
if dec.NumChans != 1 {
|
||||
tb.Fatalf("expected mono wav, got channels=%d", dec.NumChans)
|
||||
}
|
||||
return buf.AsFloat32Buffer().Data
|
||||
}
|
||||
|
||||
// helperLoadSampleWithMeta loads wav and returns samples with sample rate and channels
|
||||
func helperLoadSampleWithMeta(tb testing.TB, path string) ([]float32, int, int) {
|
||||
tb.Helper()
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
tb.Fatalf("open sample: %v", err)
|
||||
}
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
dec := wav.NewDecoder(fh)
|
||||
buf, err := dec.FullPCMBuffer()
|
||||
if err != nil {
|
||||
tb.Fatalf("decode wav: %v", err)
|
||||
}
|
||||
if dec.NumChans != 1 {
|
||||
tb.Fatalf("expected mono wav, got channels=%d", dec.NumChans)
|
||||
}
|
||||
return buf.AsFloat32Buffer().Data, int(dec.SampleRate), int(dec.NumChans)
|
||||
}
|
||||
|
||||
func helperNewModel(t *testing.T) (whisper.Model, func()) {
|
||||
t.Helper()
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
model, err := whisper.New(ModelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("load model: %v", err)
|
||||
}
|
||||
return model, func() { _ = model.Close() }
|
||||
}
|
||||
|
||||
func helperNewModelContext(t *testing.T) (*whisper.ModelContext, func()) {
|
||||
t.Helper()
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
model, err := whisper.NewModelContext(ModelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("load model ctx: %v", err)
|
||||
}
|
||||
return model, func() { _ = model.Close() }
|
||||
}
|
||||
|
||||
func helperNewParams(t *testing.T, model *whisper.ModelContext, configure whisper.ParamsConfigure) *whisper.Parameters {
|
||||
t.Helper()
|
||||
params, err := whisper.NewParameters(model, whisper.SAMPLING_GREEDY, configure)
|
||||
if err != nil {
|
||||
t.Fatalf("new params: %v", err)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func helperProcessOnce(t *testing.T, ctx whisper.Context, data []float32) {
|
||||
t.Helper()
|
||||
if err := ctx.Process(data, nil, nil, nil); err != nil {
|
||||
t.Fatalf("process: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func helperFirstSegmentText(t *testing.T, ctx whisper.Context) string {
|
||||
t.Helper()
|
||||
seg, err := ctx.NextSegment()
|
||||
if err != nil {
|
||||
t.Fatalf("next segment: %v", err)
|
||||
}
|
||||
return seg.Text
|
||||
}
|
||||
|
||||
// helperNewStatelessContext creates a fresh stateless context and returns a cleanup func
|
||||
func helperNewStatelessContext(t *testing.T) (whisper.Context, func()) {
|
||||
t.Helper()
|
||||
model, closeModel := helperNewModelContext(t)
|
||||
params := helperNewParams(t, model, nil)
|
||||
ctx, err := whisper.NewStatelessContext(model, params)
|
||||
if err != nil {
|
||||
t.Fatalf("new stateless context: %v", err)
|
||||
}
|
||||
cleanup := func() {
|
||||
_ = ctx.Close()
|
||||
closeModel()
|
||||
}
|
||||
return ctx, cleanup
|
||||
}
|
||||
|
||||
// helperNewStatefulContext creates a fresh stateful context and returns a cleanup func
|
||||
func helperNewStatefulContext(t *testing.T) (whisper.Context, func()) {
|
||||
t.Helper()
|
||||
model, closeModel := helperNewModelContext(t)
|
||||
params := helperNewParams(t, model, nil)
|
||||
ctx, err := whisper.NewStatefulContext(model, params)
|
||||
if err != nil {
|
||||
t.Fatalf("new stateful context: %v", err)
|
||||
}
|
||||
cleanup := func() {
|
||||
_ = ctx.Close()
|
||||
closeModel()
|
||||
}
|
||||
return ctx, cleanup
|
||||
}
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
package whisper
|
||||
|
||||
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
|
||||
type tokenIdentifier struct {
|
||||
ctx *ctxAccessor
|
||||
}
|
||||
|
||||
func newTokenIdentifier(whisperContext *ctxAccessor) *tokenIdentifier {
|
||||
return &tokenIdentifier{
|
||||
ctx: whisperContext,
|
||||
}
|
||||
}
|
||||
|
||||
// Token type checking methods (model-specific vocabulary)
|
||||
func (ti *tokenIdentifier) IsBEG(t Token) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_beg(), nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsEOT(t Token) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_eot(), nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsSOT(t Token) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_sot(), nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsPREV(t Token) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_prev(), nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsSOLM(t Token) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_solm(), nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsNOT(t Token) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_not(), nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsLANG(t Token, lang string) (bool, error) {
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if id := ctx.Whisper_lang_id(lang); id >= 0 {
|
||||
return whisper.Token(t.Id) == ctx.Whisper_token_lang(id), nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (ti *tokenIdentifier) IsText(t Token) (bool, error) {
|
||||
// Check if it's any of the special tokens
|
||||
if isBeg, _ := ti.IsBEG(t); isBeg {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if isSot, _ := ti.IsSOT(t); isSot {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
ctx, err := ti.ctx.context()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if whisper.Token(t.Id) >= ctx.Whisper_token_eot() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if isPrev, _ := ti.IsPREV(t); isPrev {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if isSolm, _ := ti.IsSOLM(t); isSolm {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if isNot, _ := ti.IsNOT(t); isNot {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
|
@ -1,6 +1,16 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
ModelPath = "../../models/ggml-small.en.bin"
|
||||
ModelPath = "../../models/ggml-tiny.en.bin"
|
||||
SamplePath = "../../samples/jfk.wav"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// whisper.DisableLogs()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
package whisper
|
||||
|
||||
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
|
||||
type ctxAccessor struct {
|
||||
ctx *whisper.Context
|
||||
}
|
||||
|
||||
func newCtxAccessor(ctx *whisper.Context) *ctxAccessor {
|
||||
return &ctxAccessor{
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *ctxAccessor) close() error {
|
||||
if ctx.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx.ctx.Whisper_free()
|
||||
ctx.ctx = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *ctxAccessor) isClosed() bool {
|
||||
return ctx.ctx == nil
|
||||
}
|
||||
|
||||
func (ctx *ctxAccessor) context() (*whisper.Context, error) {
|
||||
if ctx.isClosed() {
|
||||
return nil, ErrModelClosed
|
||||
}
|
||||
|
||||
return ctx.ctx, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
w "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testModelPathCtx = "../../models/ggml-tiny.en.bin"
|
||||
|
||||
func TestWhisperCtx_NilWrapper(t *testing.T) {
|
||||
wctx := newCtxAccessor(nil)
|
||||
|
||||
assert.True(t, wctx.isClosed())
|
||||
|
||||
raw, err := wctx.context()
|
||||
assert.Nil(t, raw)
|
||||
require.ErrorIs(t, err, ErrModelClosed)
|
||||
|
||||
require.NoError(t, wctx.close())
|
||||
// idempotent
|
||||
require.NoError(t, wctx.close())
|
||||
}
|
||||
|
||||
func TestWhisperCtx_Lifecycle(t *testing.T) {
|
||||
if _, err := os.Stat(testModelPathCtx); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", testModelPathCtx)
|
||||
}
|
||||
|
||||
raw := w.Whisper_init(testModelPathCtx)
|
||||
require.NotNil(t, raw)
|
||||
|
||||
wctx := newCtxAccessor(raw)
|
||||
assert.False(t, wctx.isClosed())
|
||||
|
||||
got, err := wctx.context()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
// close frees underlying ctx and marks closed
|
||||
require.NoError(t, wctx.close())
|
||||
assert.True(t, wctx.isClosed())
|
||||
|
||||
got, err = wctx.context()
|
||||
assert.Nil(t, got)
|
||||
require.ErrorIs(t, err, ErrModelClosed)
|
||||
|
||||
// idempotent
|
||||
require.NoError(t, wctx.close())
|
||||
// no further free; raw already freed by wctx.Close()
|
||||
}
|
||||
|
||||
func TestWhisperCtx_FromModelLifecycle(t *testing.T) {
|
||||
if _, err := os.Stat(testModelPathCtx); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", testModelPathCtx)
|
||||
}
|
||||
|
||||
modelNew, err := New(testModelPathCtx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, modelNew)
|
||||
|
||||
model := modelNew.(*ModelContext)
|
||||
|
||||
wc := model.ctxAccessor()
|
||||
require.NotNil(t, wc)
|
||||
|
||||
// Should be usable before model.Close
|
||||
raw, err := wc.context()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, raw)
|
||||
|
||||
// Close model should close underlying context
|
||||
require.NoError(t, model.Close())
|
||||
|
||||
assert.True(t, wc.isClosed())
|
||||
raw, err = wc.context()
|
||||
assert.Nil(t, raw)
|
||||
require.ErrorIs(t, err, ErrModelClosed)
|
||||
|
||||
// Idempotent close on wrapper
|
||||
require.NoError(t, wc.close())
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
package whisper
|
||||
|
||||
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
|
||||
type whisperState struct {
|
||||
state *whisper.State
|
||||
}
|
||||
|
||||
func newWhisperState(state *whisper.State) *whisperState {
|
||||
return &whisperState{
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *whisperState) close() error {
|
||||
if s.state == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.state.Whisper_free_state()
|
||||
s.state = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *whisperState) unsafeState() (*whisper.State, error) {
|
||||
if s.state == nil {
|
||||
return nil, ErrModelClosed
|
||||
}
|
||||
|
||||
return s.state, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
package whisper
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
w "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
assert "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testModelPathState = "../../models/ggml-tiny.en.bin"
|
||||
|
||||
func TestWhisperState_NilWrapper(t *testing.T) {
|
||||
ws := newWhisperState(nil)
|
||||
|
||||
state, err := ws.unsafeState()
|
||||
assert.Nil(t, state)
|
||||
require.ErrorIs(t, err, ErrModelClosed)
|
||||
|
||||
require.NoError(t, ws.close())
|
||||
// idempotent
|
||||
require.NoError(t, ws.close())
|
||||
}
|
||||
|
||||
func TestWhisperState_Lifecycle(t *testing.T) {
|
||||
if _, err := os.Stat(testModelPathState); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", testModelPathState)
|
||||
}
|
||||
|
||||
ctx := w.Whisper_init(testModelPathState)
|
||||
require.NotNil(t, ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
state := ctx.Whisper_init_state()
|
||||
require.NotNil(t, state)
|
||||
|
||||
ws := newWhisperState(state)
|
||||
|
||||
got, err := ws.unsafeState()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
// close frees underlying state and marks closed
|
||||
require.NoError(t, ws.close())
|
||||
|
||||
got, err = ws.unsafeState()
|
||||
assert.Nil(t, got)
|
||||
require.ErrorIs(t, err, ErrModelClosed)
|
||||
|
||||
// idempotent
|
||||
require.NoError(t, ws.close())
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package whisper
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
|
|
@ -14,6 +15,7 @@ import (
|
|||
#cgo darwin LDFLAGS: -lggml-metal -lggml-blas
|
||||
#cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics
|
||||
#include <whisper.h>
|
||||
#include <ggml.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void callNewSegment(void* user_data, int new);
|
||||
|
|
@ -59,6 +61,22 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
|
|||
params.progress_callback_user_data = (void*)(ctx);
|
||||
return params;
|
||||
}
|
||||
|
||||
// Disable all C-side logging (whisper.cpp and ggml)
|
||||
static void go_cb_log_disable(enum ggml_log_level level, const char * text, void * user_data) {
|
||||
(void) level; (void) text; (void) user_data;
|
||||
}
|
||||
|
||||
static void whisper_log_disable_all(void) {
|
||||
ggml_log_set(go_cb_log_disable, NULL);
|
||||
whisper_log_set(go_cb_log_disable, NULL);
|
||||
}
|
||||
|
||||
// Enable default logging (stdout) for whisper.cpp and ggml
|
||||
static void whisper_log_enable_default(void) {
|
||||
ggml_log_set(NULL, NULL);
|
||||
whisper_log_set(NULL, NULL);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
|
|
@ -67,10 +85,13 @@ import "C"
|
|||
|
||||
// Type aliases over the raw whisper.cpp C types so the rest of the bindings
// can refer to them by Go names. Conversions to and from the C types are
// direct casts; no data is copied.
type (
	Context          C.struct_whisper_context        // opaque model context handle
	State            C.struct_whisper_state          // per-inference state (enables concurrent use of one Context)
	Token            C.whisper_token                 // integer token id
	TokenData        C.struct_whisper_token_data     // token id plus probabilities/timestamps
	SamplingStrategy C.enum_whisper_sampling_strategy
	Params           C.struct_whisper_full_params    // parameters for whisper_full
	Timings          C.struct_whisper_timings        // raw timing counters
	ContextParams    C.struct_whisper_context_params // model-load options (GPU etc.)
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -96,6 +117,12 @@ var (
|
|||
ErrInvalidLanguage = errors.New("invalid language")
|
||||
)
|
||||
|
||||
// DisableLogs disables all logging coming from the C libraries (whisper.cpp and ggml).
// Call once early in program startup if you want to silence device/backend prints.
// There is a matching C helper to restore default logging; this function only disables.
func DisableLogs() {
	C.whisper_log_disable_all()
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
|
|
@ -111,11 +138,54 @@ func Whisper_init(path string) *Context {
|
|||
}
|
||||
}
|
||||
|
||||
// Whisper_context_default_params returns default model context params
|
||||
func Whisper_context_default_params() ContextParams {
|
||||
return ContextParams(C.whisper_context_default_params())
|
||||
}
|
||||
|
||||
// SetUseGPU enables or disables GPU acceleration on the model context (if available)
|
||||
func (p *ContextParams) SetUseGPU(v bool) {
|
||||
if v {
|
||||
p.use_gpu = C.bool(true)
|
||||
} else {
|
||||
p.use_gpu = C.bool(false)
|
||||
}
|
||||
}
|
||||
|
||||
// SetGPUDevice selects the GPU device index for the model context (CUDA)
|
||||
func (p *ContextParams) SetGPUDevice(n int) {
|
||||
p.gpu_device = C.int(n)
|
||||
}
|
||||
|
||||
// Whisper_init_with_params allocates and initializes a model using custom context params
|
||||
func Whisper_init_with_params(path string, params ContextParams) *Context {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil {
|
||||
return (*Context)(ctx)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the model.
|
||||
func (ctx *Context) Whisper_free() {
|
||||
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// Allocates a new state associated with the context. Returns nil on failure.
|
||||
func (ctx *Context) Whisper_init_state() *State {
|
||||
if s := C.whisper_init_state((*C.struct_whisper_context)(ctx)); s != nil {
|
||||
return (*State)(s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the state.
|
||||
func (s *State) Whisper_free_state() {
|
||||
C.whisper_free_state((*C.struct_whisper_state)(s))
|
||||
}
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the provided whisper context.
|
||||
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
||||
|
|
@ -126,6 +196,15 @@ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram into the provided state.
|
||||
func (ctx *Context) Whisper_pcm_to_mel_with_state(state *State, data []float32, threads int) error {
|
||||
if C.whisper_pcm_to_mel_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
|
||||
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||
// n_mel must be 80
|
||||
|
|
@ -137,6 +216,15 @@ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Set a custom log mel spectrogram into the provided state.
|
||||
func (ctx *Context) Whisper_set_mel_with_state(state *State, data []float32, n_mel int) error {
|
||||
if C.whisper_set_mel_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
|
||||
// offset can be used to specify the offset of the first frame in the spectrogram.
|
||||
|
|
@ -148,6 +236,15 @@ func (ctx *Context) Whisper_encode(offset, threads int) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Run the Whisper encoder using the provided state.
|
||||
func (ctx *Context) Whisper_encode_with_state(state *State, offset, threads int) error {
|
||||
if C.whisper_encode_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||
// Make sure to call whisper_encode() first.
|
||||
// tokens + n_tokens is the provided context for the decoder.
|
||||
|
|
@ -160,6 +257,15 @@ func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Run the Whisper decoder using the provided state.
|
||||
func (ctx *Context) Whisper_decode_with_state(state *State, tokens []Token, past, threads int) error {
|
||||
if C.whisper_decode_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the provided text into tokens. The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success
|
||||
func (ctx *Context) Whisper_tokenize(text string, tokens []Token) (int, error) {
|
||||
|
|
@ -181,6 +287,10 @@ func (ctx *Context) Whisper_lang_id(lang string) int {
|
|||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
|
||||
func Whisper_lang_id_str(lang string) int {
|
||||
return int(C.whisper_lang_id(C.CString(lang)))
|
||||
}
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
func Whisper_lang_max_id() int {
|
||||
return int(C.whisper_lang_max_id())
|
||||
|
|
@ -205,6 +315,16 @@ func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float3
|
|||
}
|
||||
}
|
||||
|
||||
// Use mel data at offset_ms to auto-detect language using the provided state.
|
||||
func (ctx *Context) Whisper_lang_auto_detect_with_state(state *State, offset_ms, n_threads int) ([]float32, error) {
|
||||
probs := make([]float32, Whisper_lang_max_id()+1)
|
||||
if n := int(C.whisper_lang_auto_detect_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
|
||||
return nil, ErrAutoDetectFailed
|
||||
} else {
|
||||
return probs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_len() int {
|
||||
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
|
@ -290,6 +410,32 @@ func (ctx *Context) Whisper_reset_timings() {
|
|||
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
|
||||
}
|
||||
|
||||
// TimingsGo is a Go-friendly copy of whisper_timings.
// Field names suggest millisecond counters for the sampling, encode, decode,
// batch-decode and prompt stages — units come from whisper.cpp; confirm
// against whisper.h when upgrading.
type TimingsGo struct {
	SampleMS float32 // token sampling
	EncodeMS float32 // encoder
	DecodeMS float32 // decoder
	BatchdMS float32 // batched decode
	PromptMS float32 // prompt processing
}
|
||||
|
||||
// Whisper_get_timings_go retrieves timing counters and converts them to TimingsGo.
// Returns (zero value, false) when the C side yields a nil timings pointer.
func (ctx *Context) Whisper_get_timings_go() (TimingsGo, bool) {
	t := C.whisper_get_timings((*C.struct_whisper_context)(ctx))
	if t == nil {
		return TimingsGo{}, false
	}
	// The C struct is 5 consecutive floats; reinterpret and copy.
	// NOTE(review): this relies on struct whisper_timings staying exactly five
	// packed floats in this order — confirm against whisper.h when upgrading.
	arr := (*[5]C.float)(unsafe.Pointer(t))
	return TimingsGo{
		SampleMS: float32(arr[0]),
		EncodeMS: float32(arr[1]),
		DecodeMS: float32(arr[2]),
		BatchdMS: float32(arr[3]),
		PromptMS: float32(arr[4]),
	}, true
}
|
||||
|
||||
// Print system information
|
||||
func Whisper_print_system_info() string {
|
||||
return C.GoString(C.whisper_print_system_info())
|
||||
|
|
@ -323,6 +469,28 @@ func (ctx *Context) Whisper_full(
|
|||
}
|
||||
}
|
||||
|
||||
// Run the entire model using the provided state: PCM -> mel -> encoder -> decoder -> text
|
||||
func (ctx *Context) Whisper_full_with_state(
|
||||
state *State,
|
||||
params Params,
|
||||
samples []float32,
|
||||
encoderBeginCallback func() bool,
|
||||
newSegmentCallback func(int),
|
||||
progressCallback func(int),
|
||||
) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
registerProgressCallback(ctx, progressCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
defer registerProgressCallback(ctx, nil)
|
||||
if C.whisper_full_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
return ErrConversionFailed
|
||||
}
|
||||
}
|
||||
|
||||
// Split the input audio in chunks and process each chunk separately using whisper_full()
|
||||
// It seems this approach can offer some speedup in some cases.
|
||||
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
|
|
@ -357,102 +525,157 @@ func (ctx *Context) Whisper_full_n_segments() int {
|
|||
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_n_segments_from_state(state *State) int {
|
||||
return int(C.whisper_full_n_segments_from_state((*C.struct_whisper_state)(state)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_segment_t0_from_state(state *State, segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t0_from_state((*C.struct_whisper_state)(state), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the start and end time of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_segment_t1_from_state(state *State, segment int) int64 {
|
||||
return int64(C.whisper_full_get_segment_t1_from_state((*C.struct_whisper_state)(state), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the text of the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
|
||||
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_segment_text_from_state(state *State, segment int) string {
|
||||
return C.GoString(C.whisper_full_get_segment_text_from_state((*C.struct_whisper_state)(state), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get number of tokens in the specified segment.
|
||||
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
|
||||
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_n_tokens_from_state(state *State, segment int) int {
|
||||
return int(C.whisper_full_n_tokens_from_state((*C.struct_whisper_state)(state), C.int(segment)))
|
||||
}
|
||||
|
||||
// Get the token text of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
|
||||
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_token_text_from_state(state *State, segment int, token int) string {
|
||||
return C.GoString(C.whisper_full_get_token_text_from_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the token of the specified token index in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
|
||||
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_token_id_from_state(state *State, segment int, token int) Token {
|
||||
return Token(C.whisper_full_get_token_id_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get token data for the specified token in the specified segment.
|
||||
// This contains probabilities, timestamps, etc.
|
||||
func (ctx *Context) Whisper_full_get_token_data(segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_token_data_from_state(state *State, segment int, token int) TokenData {
|
||||
return TokenData(C.whisper_full_get_token_data_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
// Get the probability of the specified token in the specified segment.
|
||||
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_get_token_p_from_state(state *State, segment int, token int) float32 {
|
||||
return float32(C.whisper_full_get_token_p_from_state((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_full_lang_id_from_state(state *State) int {
|
||||
return int(C.whisper_full_lang_id_from_state((*C.struct_whisper_state)(state)))
|
||||
}
|
||||
|
||||
func (ctx *Context) Whisper_n_len_from_state(state *State) int {
|
||||
return int(C.whisper_n_len_from_state((*C.struct_whisper_state)(state)))
|
||||
}
|
||||
|
||||
// Whisper_get_logits_from_state returns the logits stored in state as a
// float32 slice of length ctx.Whisper_n_vocab().
// NOTE(review): the huge-array cast reinterprets the C float pointer; the
// returned slice aliases C-owned memory and is only valid while the state is
// alive — do not retain it past Whisper_free_state.
func (ctx *Context) Whisper_get_logits_from_state(state *State) []float32 {
	return (*[1 << 30]float32)(unsafe.Pointer(C.whisper_get_logits_from_state((*C.struct_whisper_state)(state))))[:ctx.Whisper_n_vocab()]
}
|
||||
|
||||
// Get whether the next segment is predicted as a speaker turn (tinydiarize)
|
||||
func (ctx *Context) Whisper_full_get_segment_speaker_turn_next_from_state(state *State, segment int) bool {
|
||||
return bool(C.whisper_full_get_segment_speaker_turn_next_from_state((*C.struct_whisper_state)(state), C.int(segment)))
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CALLBACKS
|
||||
|
||||
// Callback registries keyed by the context pointer. sync.Map is used because
// the //export entry points may be invoked from C while Go code registers or
// unregisters callbacks. (The span contained both the old plain-map and new
// sync.Map declarations from the diff; only the sync.Map version is kept.)
var (
	cbNewSegment   sync.Map // map[unsafe.Pointer]func(int)
	cbProgress     sync.Map // map[unsafe.Pointer]func(int)
	cbEncoderBegin sync.Map // map[unsafe.Pointer]func() bool
)
|
||||
|
||||
func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||
k := unsafe.Pointer(ctx)
|
||||
if fn == nil {
|
||||
delete(cbNewSegment, unsafe.Pointer(ctx))
|
||||
cbNewSegment.Delete(k)
|
||||
} else {
|
||||
cbNewSegment[unsafe.Pointer(ctx)] = fn
|
||||
cbNewSegment.Store(k, fn)
|
||||
}
|
||||
}
|
||||
|
||||
func registerProgressCallback(ctx *Context, fn func(int)) {
|
||||
k := unsafe.Pointer(ctx)
|
||||
if fn == nil {
|
||||
delete(cbProgress, unsafe.Pointer(ctx))
|
||||
cbProgress.Delete(k)
|
||||
} else {
|
||||
cbProgress[unsafe.Pointer(ctx)] = fn
|
||||
cbProgress.Store(k, fn)
|
||||
}
|
||||
}
|
||||
|
||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||
k := unsafe.Pointer(ctx)
|
||||
if fn == nil {
|
||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||
cbEncoderBegin.Delete(k)
|
||||
} else {
|
||||
cbEncoderBegin[unsafe.Pointer(ctx)] = fn
|
||||
cbEncoderBegin.Store(k, fn)
|
||||
}
|
||||
}
|
||||
|
||||
//export callNewSegment
|
||||
func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||
if fn, ok := cbNewSegment[user_data]; ok {
|
||||
fn(int(new))
|
||||
if v, ok := cbNewSegment.Load(user_data); ok {
|
||||
v.(func(int))(int(new))
|
||||
}
|
||||
}
|
||||
|
||||
//export callProgress
|
||||
func callProgress(user_data unsafe.Pointer, progress C.int) {
|
||||
if fn, ok := cbProgress[user_data]; ok {
|
||||
fn(int(progress))
|
||||
if v, ok := cbProgress.Load(user_data); ok {
|
||||
v.(func(int))(int(progress))
|
||||
}
|
||||
}
|
||||
|
||||
//export callEncoderBegin
|
||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||
if fn() {
|
||||
if v, ok := cbEncoderBegin.Load(user_data); ok {
|
||||
if v.(func() bool)() {
|
||||
return C.bool(true)
|
||||
} else {
|
||||
return C.bool(false)
|
||||
}
|
||||
return C.bool(false)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package whisper_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -13,10 +15,15 @@ import (
|
|||
)
|
||||
|
||||
// Paths used by the tests. The tiny English model matches the Makefile's
// model-tiny download target. (The span contained both the old small-model
// and new tiny-model lines from the diff; only the tiny path is kept.)
const (
	ModelPath  = "models/ggml-tiny.en.bin"
	SamplePath = "samples/jfk.wav"
)
|
||||
|
||||
// TestMain is the package's test entry point.
func TestMain(m *testing.M) {
	// whisper.DisableLogs() // temporarily disabled to see error messages
	code := m.Run()
	os.Exit(code)
}
|
||||
|
||||
func Test_Whisper_000(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
|
|
@ -39,7 +46,7 @@ func Test_Whisper_001(t *testing.T) {
|
|||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
|
|
@ -89,7 +96,7 @@ func Test_Whisper_003(t *testing.T) {
|
|||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer fh.Close()
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
|
|
@ -111,3 +118,157 @@ func Test_Whisper_003(t *testing.T) {
|
|||
t.Logf("%s: %f", whisper.Whisper_lang_str(i), p)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Whisper_State_Init_Free(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
state := ctx.Whisper_init_state()
|
||||
assert.NotNil(state)
|
||||
state.Whisper_free_state()
|
||||
}
|
||||
|
||||
func Test_Whisper_Full_With_State(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
state := ctx.Whisper_init_state()
|
||||
assert.NotNil(state)
|
||||
defer state.Whisper_free_state()
|
||||
|
||||
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
|
||||
// Run using state
|
||||
err = ctx.Whisper_full_with_state(state, params, data, nil, nil, nil)
|
||||
assert.NoError(err)
|
||||
|
||||
// Validate results are stored in state
|
||||
nSegments := ctx.Whisper_full_n_segments_from_state(state)
|
||||
assert.GreaterOrEqual(nSegments, 1)
|
||||
text := ctx.Whisper_full_get_segment_text_from_state(state, 0)
|
||||
assert.NotEmpty(text)
|
||||
}
|
||||
|
||||
func Test_Whisper_Lang_Auto_Detect_With_State(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, model not found:", ModelPath)
|
||||
}
|
||||
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping test, sample not found:", SamplePath)
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(SamplePath)
|
||||
assert.NoError(err)
|
||||
defer func() { _ = fh.Close() }()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
assert.NoError(err)
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
|
||||
ctx := whisper.Whisper_init(ModelPath)
|
||||
assert.NotNil(ctx)
|
||||
defer ctx.Whisper_free()
|
||||
|
||||
state := ctx.Whisper_init_state()
|
||||
assert.NotNil(state)
|
||||
defer state.Whisper_free_state()
|
||||
|
||||
threads := runtime.NumCPU()
|
||||
// Prepare mel into state then detect
|
||||
assert.NoError(ctx.Whisper_pcm_to_mel_with_state(state, data, threads))
|
||||
probs, err := ctx.Whisper_lang_auto_detect_with_state(state, 0, threads)
|
||||
assert.NoError(err)
|
||||
assert.Equal(whisper.Whisper_lang_max_id()+1, len(probs))
|
||||
}
|
||||
|
||||
// Test_Whisper_Concurrent_With_State runs two transcriptions from two
// goroutines, each with its own State but sharing one Context. A mutex
// serializes the actual calls into the shared context (per the upstream note
// that a single context is not thread-safe), so this exercises per-state
// isolation rather than true parallel decoding.
func Test_Whisper_Concurrent_With_State(t *testing.T) {
	assert := assert.New(t)
	if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
		t.Skip("Skipping test, model not found:", ModelPath)
	}
	if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
		t.Skip("Skipping test, sample not found:", SamplePath)
	}

	// Load audio once
	fh, err := os.Open(SamplePath)
	assert.NoError(err)
	defer func() { _ = fh.Close() }()
	dec := wav.NewDecoder(fh)
	buf, err := dec.FullPCMBuffer()
	assert.NoError(err)
	data := buf.AsFloat32Buffer().Data

	ctx := whisper.Whisper_init(ModelPath)
	assert.NotNil(ctx)
	defer ctx.Whisper_free()

	// Each goroutine has its own state
	state1 := ctx.Whisper_init_state()
	state2 := ctx.Whisper_init_state()
	assert.NotNil(state1)
	assert.NotNil(state2)
	defer state1.Whisper_free_state()
	defer state2.Whisper_free_state()

	params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)

	var wg sync.WaitGroup
	var mu sync.Mutex // guard calls into shared ctx, per upstream note not thread-safe for same context
	// Buffered to hold one result per worker so neither goroutine blocks on send.
	errs := make(chan error, 2)

	// worker runs the full pipeline on its own state while holding the
	// context mutex, then reports its error (nil on success) on errs.
	worker := func(state *whisper.State) {
		defer wg.Done()
		mu.Lock()
		err := ctx.Whisper_full_with_state(state, params, data, nil, nil, nil)
		if err == nil {
			n := ctx.Whisper_full_n_segments_from_state(state)
			if n <= 0 {
				err = errors.New("no segments")
			} else {
				_ = ctx.Whisper_full_get_segment_text_from_state(state, 0)
			}
		}
		mu.Unlock()
		errs <- err
	}

	wg.Add(2)
	go worker(state1)
	go worker(state2)
	wg.Wait()
	// Close so the range below terminates after draining both results.
	close(errs)

	for e := range errs {
		assert.NoError(e)
	}
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue