diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java index 4bcdb6b0..66ec5d70 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java @@ -20,7 +20,7 @@ public class WhisperContextParams extends Structure { /** Use GPU for inference (default = true) */ public CBool use_gpu; - /** Use flash attention (default = false) */ + /** Use flash attention (default = true) */ public CBool flash_attn; /** CUDA device to use (default = 0) */ diff --git a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java index bf37e519..e5b22cf8 100644 --- a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java +++ b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.*; import io.github.ggerganov.whispercpp.bean.WhisperSegment; import io.github.ggerganov.whispercpp.params.CBool; +import io.github.ggerganov.whispercpp.params.WhisperContextParams; import io.github.ggerganov.whispercpp.params.WhisperFullParams; import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; import org.junit.jupiter.api.BeforeAll; @@ -25,7 +26,9 @@ class WhisperCppTest { //String modelName = "../../models/ggml-tiny.bin"; String modelName = "../../models/ggml-tiny.en.bin"; try { - whisper.initContext(modelName); + WhisperContextParams.ByValue contextParams = whisper.getContextDefaultParams(); + contextParams.useFlashAttn(false); // Disable flash attention + whisper.initContext(modelName, contextParams); //whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); //whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); modelInitialised = true;