go: allow set a flash_attn option in Go binding, better sugar apis
This commit is contained in:
parent
b1f1a14198
commit
7c74def7c9
|
|
@ -1,20 +1,5 @@
|
|||
package whisper
|
||||
|
||||
type (
|
||||
contextParamsOption interface{ apply(*ContextParams) }
|
||||
contextParamsOptionFunc func(*ContextParams)
|
||||
)
|
||||
|
||||
func (fn contextParamsOptionFunc) apply(to *ContextParams) {
|
||||
fn(to)
|
||||
}
|
||||
|
||||
func WithUseGPU(v bool) contextParamsOption {
|
||||
return contextParamsOptionFunc(func(p *ContextParams) {
|
||||
p.SetUseGPU(v)
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ContextParams) UseGPU() bool {
|
||||
return bool(p.use_gpu)
|
||||
}
|
||||
|
|
@ -22,3 +7,11 @@ func (p *ContextParams) UseGPU() bool {
|
|||
func (p *ContextParams) SetUseGPU(v bool) {
|
||||
p.use_gpu = toBool(v)
|
||||
}
|
||||
|
||||
func (p *ContextParams) UseFlashAttention() bool {
|
||||
return bool(p.flash_attn)
|
||||
}
|
||||
|
||||
func (p *ContextParams) SetUseFlashAttention(v bool) {
|
||||
p.flash_attn = toBool(v)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,19 +23,21 @@ var _ Model = (*model)(nil)
|
|||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func New(path string) (Model, error) {
|
||||
model := new(model)
|
||||
func New(path string, options ...modelOption) (Model, error) {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
} else if ctx := whisper.Whisper_init(path); ctx == nil {
|
||||
return nil, ErrUnableToLoadModel
|
||||
} else {
|
||||
model.ctx = ctx
|
||||
model.path = path
|
||||
}
|
||||
|
||||
// Return success
|
||||
return model, nil
|
||||
params := whisper.DefaultContextParams()
|
||||
for _, option := range options {
|
||||
option.apply(¶ms)
|
||||
}
|
||||
|
||||
if ctx := whisper.Whisper_init_with_params(path, params); ctx != nil {
|
||||
return &model{path, ctx}, nil
|
||||
}
|
||||
|
||||
return nil, ErrUnableToLoadModel
|
||||
}
|
||||
|
||||
func (model *model) Close() error {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
package whisper
|
||||
|
||||
import whisper "github.com/ggerganov/whisper.cpp/bindings/go"
|
||||
|
||||
type ContextParams = whisper.ContextParams
|
||||
|
||||
type (
|
||||
modelOption interface{ apply(*ContextParams) }
|
||||
modelOptionFunc func(*ContextParams)
|
||||
)
|
||||
|
||||
func (fn modelOptionFunc) apply(to *ContextParams) {
|
||||
fn(to)
|
||||
}
|
||||
|
||||
func WithUseGPU(v bool) modelOption {
|
||||
return modelOptionFunc(func(p *ContextParams) {
|
||||
p.SetUseGPU(v)
|
||||
})
|
||||
}
|
||||
|
||||
func WithUseFlashAttention(v bool) modelOption {
|
||||
return modelOptionFunc(func(p *ContextParams) {
|
||||
p.SetUseFlashAttention(v)
|
||||
})
|
||||
}
|
||||
|
|
@ -102,13 +102,13 @@ var (
|
|||
|
||||
// Allocates all memory needed for the model and loads the model from the given file.
|
||||
// Returns NULL on failure.
|
||||
func Whisper_init(path string, options ...contextParamsOption) *Context {
|
||||
func Whisper_init(path string) *Context {
|
||||
return Whisper_init_with_params(path, DefaultContextParams())
|
||||
}
|
||||
|
||||
func Whisper_init_with_params(path string, params ContextParams) *Context {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
params := ContextParams(C.whisper_context_default_params())
|
||||
for _, o := range options {
|
||||
o.apply(¶ms)
|
||||
}
|
||||
if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil {
|
||||
return (*Context)(ctx)
|
||||
} else {
|
||||
|
|
@ -116,6 +116,10 @@ func Whisper_init(path string, options ...contextParamsOption) *Context {
|
|||
}
|
||||
}
|
||||
|
||||
func DefaultContextParams() ContextParams {
|
||||
return ContextParams(C.whisper_context_default_params())
|
||||
}
|
||||
|
||||
// Frees all memory allocated by the model.
|
||||
func (ctx *Context) Whisper_free() {
|
||||
C.whisper_free((*C.struct_whisper_context)(ctx))
|
||||
|
|
|
|||
Loading…
Reference in New Issue