This commit is contained in:
M1xA 2026-04-20 14:46:53 +00:00 committed by GitHub
commit b2f43d1886
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 10 deletions

View File

@ -0,0 +1,17 @@
package whisper
func (p *ContextParams) UseGPU() bool {
return bool(p.use_gpu)
}
func (p *ContextParams) SetUseGPU(v bool) {
p.use_gpu = toBool(v)
}
func (p *ContextParams) UseFlashAttention() bool {
return bool(p.flash_attn)
}
func (p *ContextParams) SetUseFlashAttention(v bool) {
p.flash_attn = toBool(v)
}

View File

@ -23,19 +23,21 @@ var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func New(path string) (Model, error) {
model := new(model)
func New(path string, options ...modelOption) (Model, error) {
if _, err := os.Stat(path); err != nil {
return nil, err
} else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
model.ctx = ctx
model.path = path
}
// Return success
return model, nil
params := whisper.DefaultContextParams()
for _, option := range options {
option.apply(&params)
}
if ctx := whisper.Whisper_init_with_params(path, params); ctx != nil {
return &model{path, ctx}, nil
}
return nil, ErrUnableToLoadModel
}
func (model *model) Close() error {

View File

@ -0,0 +1,26 @@
package whisper
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
type ContextParams = whisper.ContextParams
type (
modelOption interface{ apply(*ContextParams) }
modelOptionFunc func(*ContextParams)
)
func (fn modelOptionFunc) apply(to *ContextParams) {
fn(to)
}
func WithUseGPU(v bool) modelOption {
return modelOptionFunc(func(p *ContextParams) {
p.SetUseGPU(v)
})
}
func WithUseFlashAttention(v bool) modelOption {
return modelOptionFunc(func(p *ContextParams) {
p.SetUseFlashAttention(v)
})
}

View File

@ -71,6 +71,7 @@ type (
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
ContextParams C.struct_whisper_context_params
)
///////////////////////////////////////////////////////////////////////////////
@ -102,15 +103,23 @@ var (
// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
func Whisper_init(path string) *Context {
return Whisper_init_with_params(path, DefaultContextParams())
}
func Whisper_init_with_params(path string, params ContextParams) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil {
return (*Context)(ctx)
} else {
return nil
}
}
func DefaultContextParams() ContextParams {
return ContextParams(C.whisper_context_default_params())
}
// Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx))