From 8a67c55c8aa17c701494297fc4251187370b1044 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Sep 2025 21:28:03 +0300 Subject: [PATCH 001/104] wchess : fix link [no ci] --- examples/wchess/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wchess/README.md b/examples/wchess/README.md index d9694a1e..3d62651b 100644 --- a/examples/wchess/README.md +++ b/examples/wchess/README.md @@ -2,7 +2,7 @@ Voice-controlled chess using Whisper -Online demo: https://ggml.ai/whisper.cpp/ +Online demo: https://ggml.ai/whisper.cpp/wchess.wasm/ https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa From 47fcd7da8b72432a7c5eada529e9f5f0e0bccf56 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Sep 2025 21:37:00 +0300 Subject: [PATCH 002/104] scripts : add -nfa option [no ci] --- ci/run.sh | 2 +- scripts/bench-all.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index d98a3d86..cbe28442 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -246,7 +246,7 @@ function gg_run_bench { cd ${SRC} # set flash attention flag if enabled - fattn="" + fattn="-nfa" if [ "$BENCH_FLASH_ATTN" -eq 1 ]; then fattn="-fa" fi diff --git a/scripts/bench-all.sh b/scripts/bench-all.sh index 4c1a7a10..a15a361c 100755 --- a/scripts/bench-all.sh +++ b/scripts/bench-all.sh @@ -19,7 +19,7 @@ fi fattn="" if [ -z "$3" ] || [ "$3" -eq 0 ]; then - fattn="" + fattn="-nfa" else fattn="-fa" fi From 8c0855fd6bb115e113c0dca6255ea05f774d35f7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Sep 2025 21:40:32 +0300 Subject: [PATCH 003/104] bench : update [no ci] --- scripts/bench-all-gg.txt | 61 ++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt index 82bf6aa1..d1cdaf9a 100644 --- a/scripts/bench-all-gg.txt +++ b/scripts/bench-all-gg.txt @@ -45,20 +45,20 @@ Running ggml_mul_mat benchmark with 1 threads | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M1 Pro | METAL | tiny | 1 | 0 | 39.21 | 1.74 | 0.61 | 0.04 | 22c96b4 | -| M1 Pro | METAL | base | 1 | 0 | 70.76 | 2.60 | 0.93 | 0.06 | 22c96b4 | -| M1 Pro | METAL | small | 1 | 0 | 217.28 | 6.42 | 2.14 | 0.17 | 22c96b4 | -| M1 Pro | METAL | medium | 1 | 0 | 596.74 | 14.43 | 4.75 | 0.45 | 22c96b4 | +| M1 Pro | METAL | tiny | 1 | 0 | 32.44 | 1.71 | 0.43 | 0.04 | 8a67c55c | +| M1 Pro | METAL | base | 1 | 0 | 63.54 | 2.62 | 0.71 | 0.06 | 8a67c55c | +| M1 Pro | METAL | small | 1 | 0 | 200.30 | 5.34 | 1.72 | 0.17 | 8a67c55c | +| M1 Pro | METAL | medium | 1 | 0 | 580.06 | 11.71 | 4.18 | 0.45 | 8a67c55c | make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M1 Pro | METAL | tiny | 1 | 1 | 21.98 | 1.66 | 0.29 | 0.03 | a77d11d9 | -| M1 Pro | METAL | base | 1 | 1 | 40.55 | 2.18 | 0.43 | 0.04 | a77d11d9 | -| M1 Pro | METAL | small | 1 | 1 | 229.44 | 4.38 | 0.95 | 0.11 | a77d11d9 | -| M1 Pro | METAL | medium | 1 | 1 | 394.64 | 9.11 | 2.21 | 0.30 | a77d11d9 | +| M1 Pro | METAL | tiny | 1 | 1 | 22.09 | 1.84 | 0.43 | 0.03 | 8a67c55c | +| M1 Pro | METAL | base | 1 | 1 | 40.57 | 2.22 | 0.44 | 0.04 | 8a67c55c | +| M1 Pro | METAL | small | 1 | 1 | 135.15 | 4.23 | 0.95 | 0.12 | 8a67c55c | +| M1 Pro | METAL | medium | 1 | 1 | 395.18 | 9.14 | 2.21 | 0.30 | 8a67c55c | ## M2 Ultra @@ -218,33 +218,34 @@ make -j && ./scripts/bench-all.sh 1 1 0 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M4 Max | METAL | tiny | 1 | 0 | 10.46 | 0.81 | 0.22 | 0.01 | b57b9d3a | -| M4 Max | METAL | tiny-q8_0 | 1 | 0 | 10.64 | 0.79 | 0.23 | 0.01 | b57b9d3a | -| M4 Max | METAL | base | 1 | 0 | 19.61 | 1.32 | 0.35 | 0.02 | b57b9d3a | -| M4 Max | METAL | base-q8_0 | 1 | 0 | 20.08 | 1.25 | 0.36 | 0.02 | b57b9d3a | -| M4 Max | METAL | small | 1 | 0 | 62.59 | 2.78 | 0.78 | 0.06 | b57b9d3a | -| M4 Max | METAL | small-q8_0 | 1 | 0 | 64.30 | 2.42 | 0.78 | 0.06 | b57b9d3a | -| M4 Max | METAL | medium | 1 | 0 | 181.55 | 6.42 | 1.84 | 0.15 | b57b9d3a | -| M4 Max | METAL | medium-q8_0 | 1 | 0 | 187.79 | 5.74 | 1.83 | 0.15 | b57b9d3a | -| M4 Max | METAL | large-v2 | 1 | 0 | 335.93 | 10.56 | 3.03 | 0.26 | b57b9d3a | -| M4 Max | METAL | large-v2-q8_0 | 1 | 0 | 350.73 | 8.73 | 2.98 | 0.27 | b57b9d3a | -| M4 Max | METAL | large-v3-turbo | 1 | 0 | 301.98 | 1.82 | 0.49 | 0.04 | b57b9d3a | +| M4 Max | METAL | tiny | 1 | 0 | 10.51 | 0.86 | 0.23 | 0.01 | 47fcd7da | +| M4 Max | METAL | tiny-q8_0 | 1 | 0 | 10.73 | 0.84 | 0.24 | 0.01 | 47fcd7da | +| M4 Max | METAL | base | 1 | 0 | 19.50 | 1.34 | 0.36 | 0.02 | 47fcd7da | +| M4 Max | METAL | base-q8_0 | 1 | 0 | 20.17 | 1.25 | 0.36 | 0.02 | 47fcd7da | +| M4 Max | METAL | small | 1 | 0 | 61.91 | 2.77 | 0.78 | 0.06 | 47fcd7da | +| M4 Max | METAL | small-q8_0 | 1 | 0 | 64.17 | 2.43 | 0.78 | 0.06 | 47fcd7da | +| M4 Max | METAL | medium | 1 | 0 | 181.50 | 6.44 | 1.85 | 0.15 | 47fcd7da | +| M4 Max | METAL | medium-q8_0 | 1 | 0 | 187.71 | 5.80 | 1.84 | 0.15 | 47fcd7da | +| M4 Max | METAL | large-v2 | 1 | 0 | 335.49 | 10.49 | 3.01 | 0.26 | 47fcd7da | +| M4 Max | METAL | large-v2-q8_0 | 1 | 0 | 349.89 | 8.65 | 2.97 | 0.27 | 47fcd7da | +| M4 Max | METAL | large-v3-turbo | 1 | 0 | 301.34 | 1.83 | 0.49 | 0.04 | 47fcd7da | + make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M4 Max | METAL | tiny | 1 | 1 | 8.27 | 0.73 | 0.16 | 0.01 | a77d11d9 | -| M4 Max | METAL | tiny-q8_0 | 1 | 1 | 8.46 | 0.67 | 0.16 | 0.01 | a77d11d9 | -| M4 Max | METAL | base | 1 | 1 | 15.43 | 1.11 | 0.26 | 0.02 | a77d11d9 | -| M4 Max | METAL | base-q8_0 | 1 | 1 | 16.02 | 1.04 | 0.27 | 0.02 | a77d11d9 | -| M4 Max | METAL | small | 1 | 1 | 49.88 | 2.34 | 0.54 | 0.05 | a77d11d9 | -| M4 Max | METAL | small-q8_0 | 1 | 1 | 51.86 | 1.99 | 0.54 | 0.05 | a77d11d9 | -| M4 Max | METAL | medium | 1 | 1 | 148.17 | 5.45 | 1.27 | 0.12 | a77d11d9 | -| M4 Max | METAL | medium-q8_0 | 1 | 1 | 154.43 | 4.56 | 1.25 | 0.13 | a77d11d9 | -| M4 Max | METAL | large-v2 | 1 | 1 | 283.30 | 8.96 | 2.10 | 0.22 | a77d11d9 | -| M4 Max | METAL | large-v2-q8_0 | 1 | 1 | 298.13 | 7.28 | 2.08 | 0.23 | a77d11d9 | -| M4 Max | METAL | large-v3-turbo | 1 | 1 | 250.19 | 1.64 | 0.37 | 0.04 | a77d11d9 | +| M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.71 | 0.16 | 0.01 | 47fcd7da | +| M4 Max | METAL | tiny-q8_0 | 1 | 1 | 8.47 | 0.67 | 0.16 | 0.01 | 47fcd7da | +| M4 Max | METAL | base | 1 | 1 | 15.47 | 1.12 | 0.26 | 0.02 | 47fcd7da | +| M4 Max | METAL | base-q8_0 | 1 | 1 | 15.70 | 1.05 | 0.27 | 0.02 | 47fcd7da | +| M4 Max | METAL | small | 1 | 1 | 49.82 | 2.37 | 0.53 | 0.05 | 47fcd7da | +| M4 Max | METAL | small-q8_0 | 1 | 1 | 51.76 | 1.99 | 0.53 | 0.05 | 47fcd7da | +| M4 Max | METAL | medium | 1 | 1 | 147.76 | 5.52 | 1.27 | 0.12 | 47fcd7da | +| M4 Max | METAL | medium-q8_0 | 1 | 1 | 153.98 | 4.59 | 1.24 | 0.13 | 47fcd7da | +| M4 Max | METAL | large-v2 | 1 | 1 | 282.89 | 9.06 | 2.11 | 0.22 | 47fcd7da | +| M4 Max | METAL | large-v2-q8_0 | 1 | 1 | 296.43 | 7.44 | 2.09 | 0.23 | 47fcd7da | +| M4 Max | METAL | large-v3-turbo | 1 | 1 | 249.91 | 1.65 | 0.38 | 0.04 | 47fcd7da | # RTX 5090 From 2a5686966944a3fbf192678757afd7120d25732f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 1 Oct 2025 09:13:34 +0200 Subject: [PATCH 004/104] bindings-java : disable flash attention by default (#3445) This commit disables flash-attention for the Java binding test so that the testFullTranscribe test passes. Without this change the test was failing because the expected output mismatches after the flash-attention change: ```console but was: ``` An alternative would also be to update the expected output but it felt better to keep the same expected output and disable flash-attention and not just change the expected output to match the new behavior. --- .../ggerganov/whispercpp/params/WhisperContextParams.java | 2 +- .../java/io/github/ggerganov/whispercpp/WhisperCppTest.java | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java index 4bcdb6b0..66ec5d70 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java @@ -20,7 +20,7 @@ public class WhisperContextParams extends Structure { /** Use GPU for inference (default = true) */ public CBool use_gpu; - /** Use flash attention (default = false) */ + /** Use flash attention (default = true) */ public CBool flash_attn; /** CUDA device to use (default = 0) */ diff --git a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java index bf37e519..e5b22cf8 100644 --- a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java +++ b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.*; import io.github.ggerganov.whispercpp.bean.WhisperSegment; import io.github.ggerganov.whispercpp.params.CBool; +import io.github.ggerganov.whispercpp.params.WhisperContextParams; import io.github.ggerganov.whispercpp.params.WhisperFullParams; import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; import org.junit.jupiter.api.BeforeAll; @@ -25,7 +26,9 @@ class WhisperCppTest { //String modelName = "../../models/ggml-tiny.bin"; String modelName = "../../models/ggml-tiny.en.bin"; try { - whisper.initContext(modelName); + WhisperContextParams.ByValue contextParams = whisper.getContextDefaultParams(); + contextParams.useFlashAttn(false); // Disable flash attention + whisper.initContext(modelName, contextParams); //whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); //whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); modelInitialised = true; From 7849aff7a2e1f4234aa31b01a1870906d5431959 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Wed, 1 Oct 2025 21:33:11 +0900 Subject: [PATCH 005/104] ruby : Loose RegExp for test (#3448) --- bindings/ruby/test/test_whisper.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 12b82a8d..23479b7a 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -34,7 +34,7 @@ class TestWhisper < TestBase params = Whisper::Params.new @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| - assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text) + assert_match(/what you can do for your country/i, text) } end From c8223a8548ad64435266e551385fc51aca9ee8ab Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 6 Oct 2025 14:57:44 +0200 Subject: [PATCH 006/104] vad : fix memory leaks in VAD implementation (#3453) * vad : fix memory leak by storing ggml_context in vad context struct This commit addresses a memory leak issue in the voice activity detection (VAD) where the ggml_context is not stored within the vad context structure. The motivation for this change that this is causing the context memory to stay allocated and the tensor still point to that memory but this memory is never freed. * vad : free memory allocated for VAD hparams This commit frees the model hyperparameters allocated for the VAD context in the `whisper_vad_free` function. Specifically, it deletes the `encoder_in_channels`, `encoder_out_channels`, and `kernel_sizes` arrays allocated with `new[]` in the `whisper_vad_init` function. The motivation for this is to prevent memory leaks when the VAD. * vad: free ggml buffer in whisper_vad_free This commit frees the ggml buffer in the whisper_vad_free function to prevent memory leaks. Resolves: https://github.com/ggml-org/whisper.cpp/issues/3452 * Revert "vad : fix memory leak by storing ggml_context in vad context struct" This reverts commit aeafca437efa7fb28166703f845e321176aa62ab. * whisper : free ggml context in whisper_vad_init_context This commit frees the ggml_context after initializing the VAD context in the whisper_vad_init_context function. The motivation for this is to prevent memory leaks. --- src/whisper.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/whisper.cpp b/src/whisper.cpp index d99dd7be..39c53ba2 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -4676,6 +4676,7 @@ static bool whisper_vad_init_context(whisper_vad_context * vctx) { ggml_set_name(vctx->c_state, "c_state"); vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]); + ggml_free(ctx); if (!vctx->buffer) { WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__); return false; @@ -5420,6 +5421,9 @@ struct whisper_vad_segments * whisper_vad_segments_from_samples( void whisper_vad_free(whisper_vad_context * ctx) { if (ctx) { + if (ctx->buffer) { + ggml_backend_buffer_free(ctx->buffer); + } for (ggml_context * context : ctx->model.ctxs) { ggml_free(context); } @@ -5434,6 +5438,9 @@ void whisper_vad_free(whisper_vad_context * ctx) { ggml_backend_free(backend); } + delete[] ctx->model.hparams.encoder_in_channels; + delete[] ctx->model.hparams.encoder_out_channels; + delete[] ctx->model.hparams.kernel_sizes; delete ctx; } From 8877dfc11a9322ce1990958494cf2e41c54657eb Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Wed, 8 Oct 2025 20:45:20 +0900 Subject: [PATCH 007/104] [skip ci]Bump Ruby bindings' version to 1.3.4 (#3461) --- bindings/ruby/whispercpp.gemspec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index c6e88dff..eac35b8a 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.3' + s.version = '1.3.4' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] From 98930fded1c06e601a38903607af262f04893880 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Oct 2025 10:48:40 +0300 Subject: [PATCH 008/104] whisper : clean-up headers --- src/whisper.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 39c53ba2..a49eb59a 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -21,14 +21,12 @@ #define _USE_MATH_DEFINES #include #include -#include #include #include #include #include #include #include -#include #include #include #include @@ -36,6 +34,10 @@ #include #include +#ifdef _MSC_VER +#include +#endif + #if defined(WHISPER_BIG_ENDIAN) template static T byteswap(T value) { From 85d1d3d3dcd6e95944920ddb7ef30a016f6c5b22 Mon Sep 17 00:00:00 2001 From: Silviu Caragea Date: Fri, 10 Oct 2025 04:20:21 +0000 Subject: [PATCH 009/104] vad : free vad_segments in whisper_vad (#3463) This commit fixes multiple issues: * memory leak because vad_segments is never released * avoid segmentation fault when whisper_vad_segments_from_samples returns nullptr. * avoid potential segmentation fault when the app fails to allocate memory for filtered samples and the vad context is released but also get released withing state itself when whisper_free_state is called --- src/whisper.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index a49eb59a..8992a144 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6620,6 +6620,9 @@ static bool whisper_vad( whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples); + if(!vad_segments) + return false; + if (vad_segments->data.size() > 0) { state->has_vad_segments = true; ctx->state->vad_segments.clear(); @@ -6662,7 +6665,6 @@ static bool whisper_vad( } catch (const std::bad_alloc & /* e */) { WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__); whisper_vad_free_segments(vad_segments); - whisper_vad_free(vctx); return false; } @@ -6768,6 +6770,7 @@ static bool whisper_vad( __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples)); } + whisper_vad_free_segments(vad_segments); return true; } From d3a29d7b882ae818dceae27f22555175ad9048b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Oct 2025 11:33:01 +0300 Subject: [PATCH 010/104] minor : fix code style (#3463) --- src/whisper.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 8992a144..a212b7c9 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6620,8 +6620,9 @@ static bool whisper_vad( whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples); - if(!vad_segments) + if (!vad_segments) { return false; + } if (vad_segments->data.size() > 0) { state->has_vad_segments = true; From a0ca50f3b948515a589aed43b3bdd334a1ded1bf Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Fri, 10 Oct 2025 15:21:03 +0200 Subject: [PATCH 011/104] cli: Fix assignment for vad_min_silence_duration_ms (#3467) * cli: Fix assignment for vad_min_silence_duration_ms Found and fixed this simple copy/paste error * server : fix vad_min_silence_duration_ms assignment --------- Co-authored-by: Daniel Bevenius --- examples/cli/cli.cpp | 2 +- examples/server/server.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 457a1ff3..0739cacf 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -204,7 +204,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; } else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); } else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); } - else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(ARGV_NEXT); } else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); } else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); } else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(ARGV_NEXT); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fd9b7784..1262c3d6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -256,7 +256,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } From 85871a946971955c635f56bca24ea2a37fed6324 Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Fri, 10 Oct 2025 18:51:15 +0200 Subject: [PATCH 012/104] whisper : add support for --carry-initial-prompt (#3395) * Add support for --carry-initial-prompt * PR fixes for ruby and go * Refactoring for readability * WIP 1 * WIP 2 * PR fixes * More PR fixes * PR fix * Further simplification * d'oh * One more logic fix * Update src/whisper.cpp Co-authored-by: Georgi Gerganov * Truncate prompt_past0 upon initialization * Slight simplification --------- Co-authored-by: Georgi Gerganov --- bindings/go/params.go | 8 + .../whispercpp/params/WhisperFullParams.java | 6 +- bindings/ruby/ext/ruby_whisper_params.c | 69 ++++-- bindings/ruby/sig/whisper.rbs | 3 + bindings/ruby/test/test_params.rb | 8 + examples/cli/cli.cpp | 231 +++++++++--------- include/whisper.h | 1 + src/whisper.cpp | 93 +++++-- 8 files changed, 257 insertions(+), 162 deletions(-) diff --git a/bindings/go/params.go b/bindings/go/params.go index 95c5bfaf..d8dee57e 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) { p.print_timestamps = toBool(v) } + // Set language id func (p *Params) SetLanguage(lang int) error { if lang == -1 { @@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) { p.initial_prompt = C.CString(prompt) } +func (p *Params) SetCarryInitialPrompt(v bool) { + p.carry_initial_prompt = toBool(v) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS @@ -199,6 +204,9 @@ func (p *Params) String() string { if p.token_timestamps { str += " token_timestamps" } + if p.carry_initial_prompt { + str += " carry_initial_prompt" + } return str + ">" } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 498ff126..76ce80fb 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -157,6 +157,8 @@ public class WhisperFullParams extends Structure { /** Tokens to provide to the whisper decoder as an initial prompt. * These are prepended to any existing text context from a previous call. */ public String initial_prompt; + /** Always prepend initial_prompt for every decode chunk. */ + public CBool carry_initial_prompt; /** Prompt tokens. (int*) */ public Pointer prompt_tokens; @@ -336,8 +338,8 @@ public class WhisperFullParams extends Structure { "no_timestamps", "single_segment", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "thold_pt", "thold_ptsum", "max_len", - "split_on_word", "max_tokens", "debug_mode", "audio_ctx", - "tdrz_enable", "suppress_regex", "initial_prompt", + "split_on_word", "max_tokens", "debug_mode", "audio_ctx", + "tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 882c68d0..70417cb1 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -26,7 +26,7 @@ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); -#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36 +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37 extern VALUE cParams; extern VALUE cVADParams; @@ -46,6 +46,7 @@ static ID id_print_special; static ID id_print_progress; static ID id_print_realtime; static ID id_print_timestamps; +static ID id_carry_initial_prompt; static ID id_suppress_blank; static ID id_suppress_nst; static ID id_token_timestamps; @@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self) { BOOL_PARAMS_GETTER(self, print_timestamps) } + +/* + * call-seq: + * carry_initial_prompt -> true or false + */ +static VALUE +ruby_whisper_params_get_carry_initial_prompt(VALUE self) +{ + BOOL_PARAMS_GETTER(self, carry_initial_prompt) +} + +/* + * call-seq: + * carry_initial_prompt = bool -> bool + */ +static VALUE +ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, carry_initial_prompt, value) +} /* * call-seq: * suppress_blank = force_suppress -> force_suppress @@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) SET_PARAM_IF_SAME(max_len) SET_PARAM_IF_SAME(split_on_word) SET_PARAM_IF_SAME(initial_prompt) + SET_PARAM_IF_SAME(carry_initial_prompt) SET_PARAM_IF_SAME(offset) SET_PARAM_IF_SAME(duration) SET_PARAM_IF_SAME(max_text_tokens) @@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper) DEFINE_PARAM(max_len, 11) DEFINE_PARAM(split_on_word, 12) DEFINE_PARAM(initial_prompt, 13) - DEFINE_PARAM(diarize, 14) - DEFINE_PARAM(offset, 15) - DEFINE_PARAM(duration, 16) - DEFINE_PARAM(max_text_tokens, 17) - DEFINE_PARAM(temperature, 18) - DEFINE_PARAM(max_initial_ts, 19) - DEFINE_PARAM(length_penalty, 20) - DEFINE_PARAM(temperature_inc, 21) - DEFINE_PARAM(entropy_thold, 22) - DEFINE_PARAM(logprob_thold, 23) - DEFINE_PARAM(no_speech_thold, 24) - DEFINE_PARAM(new_segment_callback, 25) - DEFINE_PARAM(new_segment_callback_user_data, 26) - DEFINE_PARAM(progress_callback, 27) - DEFINE_PARAM(progress_callback_user_data, 28) - DEFINE_PARAM(encoder_begin_callback, 29) - DEFINE_PARAM(encoder_begin_callback_user_data, 30) - DEFINE_PARAM(abort_callback, 31) - DEFINE_PARAM(abort_callback_user_data, 32) - DEFINE_PARAM(vad, 33) - DEFINE_PARAM(vad_model_path, 34) - DEFINE_PARAM(vad_params, 35) + DEFINE_PARAM(carry_initial_prompt, 14) + DEFINE_PARAM(diarize, 15) + DEFINE_PARAM(offset, 16) + DEFINE_PARAM(duration, 17) + DEFINE_PARAM(max_text_tokens, 18) + DEFINE_PARAM(temperature, 19) + DEFINE_PARAM(max_initial_ts, 20) + DEFINE_PARAM(length_penalty, 21) + DEFINE_PARAM(temperature_inc, 22) + DEFINE_PARAM(entropy_thold, 23) + DEFINE_PARAM(logprob_thold, 24) + DEFINE_PARAM(no_speech_thold, 25) + DEFINE_PARAM(new_segment_callback, 26) + DEFINE_PARAM(new_segment_callback_user_data, 27) + DEFINE_PARAM(progress_callback, 28) + DEFINE_PARAM(progress_callback_user_data, 29) + DEFINE_PARAM(encoder_begin_callback, 30) + DEFINE_PARAM(encoder_begin_callback_user_data, 31) + DEFINE_PARAM(abort_callback, 32) + DEFINE_PARAM(abort_callback_user_data, 33) + DEFINE_PARAM(vad, 34) + DEFINE_PARAM(vad_model_path, 35) + DEFINE_PARAM(vad_params, 36) rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 0489432a..d5905dd7 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -138,6 +138,7 @@ module Whisper ?max_len: Integer, ?split_on_word: boolish, ?initial_prompt: string | nil, + ?carry_initial_prompt: boolish, ?diarize: boolish, ?offset: Integer, ?duration: Integer, @@ -236,6 +237,7 @@ module Whisper def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS + def carry_initial_prompt=: (boolish) -> boolish # Tokens to provide to the whisper decoder as initial prompt # these are prepended to any existing text context from a previous call @@ -243,6 +245,7 @@ module Whisper # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). # def initial_prompt: () -> (String | nil) + def carry_initial_prompt: () -> (true | false) def diarize=: (boolish) -> boolish diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index d5c5d140..4dd9780d 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -16,6 +16,7 @@ class TestParams < TestBase :max_len, :split_on_word, :initial_prompt, + :carry_initial_prompt, :diarize, :offset, :duration, @@ -119,6 +120,13 @@ class TestParams < TestBase assert !@params.print_timestamps end + def test_carry_initial_prompt + @params.carry_initial_prompt = true + assert @params.carry_initial_prompt + @params.carry_initial_prompt = false + assert !@params.carry_initial_prompt + end + def test_suppress_blank @params.suppress_blank = true assert @params.suppress_blank diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 0739cacf..9a54742f 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -5,6 +5,7 @@ #include "grammar-parser.h" #include +#include #include #include #include @@ -77,6 +78,7 @@ struct whisper_params { bool use_gpu = true; bool flash_attn = true; bool suppress_nst = false; + bool carry_initial_prompt = false; std::string language = "en"; std::string prompt; @@ -145,60 +147,61 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params exit(0); } #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); } - else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); } - else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); } - else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } - else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } - else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } - else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } - else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; } - else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } - else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } - else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; } - else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); } - else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if ( arg == "--print-confidence"){ params.print_confidence= true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); } - else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } - else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; } - else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } - else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } - else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; } - else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } - else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } - else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } - else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; } - else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; } - else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; } - else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-ojf" || arg == "--output-json-full") { params.output_jsn_full = params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if ( arg == "--print-confidence") { params.print_confidence= true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; } + else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } + else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; } + else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; } + else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); } // Voice Activity Detection (VAD) else if ( arg == "--vad") { params.vad = true; } else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; } @@ -224,61 +227,62 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); - fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); - fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); - fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); - fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); - fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); - fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false"); - fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); - fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", ""); - fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); - fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); - fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); - fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); - fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); - fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); - fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); - fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); - fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); - fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); + fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false"); + fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); + fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false"); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", ""); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false"); @@ -387,7 +391,11 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct const char * text = whisper_full_get_token_text(ctx, i, j); const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); + const int n_colors = (int) k_colors.size(); + int raw_col = (int) (std::pow(p, 3)*float(n_colors)); + if (raw_col < 0) raw_col = 0; + if (raw_col > n_colors - 1) raw_col = n_colors - 1; + const int col = raw_col; printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } @@ -1178,7 +1186,8 @@ int main(int argc, char ** argv) { wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); - wparams.initial_prompt = params.prompt.c_str(); + wparams.initial_prompt = params.prompt.c_str(); + wparams.carry_initial_prompt = params.carry_initial_prompt; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; diff --git a/include/whisper.h b/include/whisper.h index fcd756a9..f4cc6bf7 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -525,6 +525,7 @@ extern "C" { // use whisper_tokenize() to convert text to tokens // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224) const char * initial_prompt; + bool carry_initial_prompt; // if true, always prepend initial_prompt to every decode window (may reduce conditioning on previous text) const whisper_token * prompt_tokens; int prompt_n_tokens; diff --git a/src/whisper.cpp b/src/whisper.cpp index a212b7c9..18874309 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -140,6 +140,10 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text } while (0) #define WHISPER_MAX_DECODERS 8 + +// temperature below which we condition on past text history +static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f; + #define WHISPER_MAX_NODES 4096 static std::string format(const char * fmt, ...) { @@ -882,7 +886,10 @@ struct whisper_state { std::vector logits; std::vector result_all; - std::vector prompt_past; + + // prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1) + std::vector prompt_past0; // static carried initial prompt (if enabled) + std::vector prompt_past1; // dynamic context from decoded output int lang_id = 0; // english by default @@ -5922,9 +5929,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /* suppress_regex =*/ nullptr, - /*.initial_prompt =*/ nullptr, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, + /*.initial_prompt =*/ nullptr, + /*.carry_initial_prompt =*/ false, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, /*.language =*/ "en", /*.detect_language =*/ false, @@ -6880,17 +6888,22 @@ int whisper_full_with_state( decoder.rng = std::mt19937(j); } - // the accumulated text context so far - auto & prompt_past = state->prompt_past; + // the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1) + auto & prompt_past0 = state->prompt_past0; + auto & prompt_past1 = state->prompt_past1; if (params.no_context) { - prompt_past.clear(); + prompt_past0.clear(); + prompt_past1.clear(); } + // calculate the maximum context budget for prompt history + const int max_prompt_ctx = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2); + // prepare prompt { std::vector prompt_tokens; - // initial prompt + // tokenize the initial prompt if (!params.prompt_tokens && params.initial_prompt) { prompt_tokens.resize(1024); int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); @@ -6902,14 +6915,25 @@ int whisper_full_with_state( params.prompt_tokens = prompt_tokens.data(); params.prompt_n_tokens = prompt_tokens.size(); } - - // prepend the prompt tokens to the prompt_past if (params.prompt_tokens && params.prompt_n_tokens > 0) { - // parse tokens from the pointer - for (int i = 0; i < params.prompt_n_tokens; i++) { - prompt_past.push_back(params.prompt_tokens[i]); + if (params.carry_initial_prompt) { + if (prompt_past0.empty()) { + const int max_tokens = std::max(1, max_prompt_ctx - 1); + + if (params.prompt_n_tokens > max_tokens) { + WHISPER_LOG_WARN("%s: initial prompt is too long (%d tokens), will use only the last %d tokens\n", + __func__, params.prompt_n_tokens, max_tokens); + } + + const int n_tokens = std::min(params.prompt_n_tokens, max_tokens); + prompt_past0.assign(params.prompt_tokens + (params.prompt_n_tokens - n_tokens), params.prompt_tokens + params.prompt_n_tokens); + } + } else { + for (int i = 0; i < params.prompt_n_tokens; ++i) { + prompt_past1.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end()); } - std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); } } @@ -6995,7 +7019,8 @@ int whisper_full_with_state( // if there is a very short audio segment left to process, we remove any past prompt since it tends // to confuse the decoder and often make it repeat or hallucinate stuff if (seek > seek_start && seek + 500 >= seek_end) { - prompt_past.clear(); + prompt_past0.clear(); + prompt_past1.clear(); } int best_decoder_id = 0; @@ -7056,12 +7081,25 @@ int whisper_full_with_state( { prompt.clear(); - // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + if (params.n_max_text_ctx > 0 && t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) { + const bool can_take0 = params.carry_initial_prompt && !prompt_past0.empty(); + const bool can_take1 = !prompt_past1.empty(); - prompt = { whisper_token_prev(ctx) }; - prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + if (max_prompt_ctx > 0 && (can_take0 || can_take1)) { + // Always start with previous token marker to connect continuity + prompt.push_back(whisper_token_prev(ctx)); + + // Take static tokens (initial prompt) first + int n_take0 = 0; + if (can_take0) { + n_take0 = prompt_past0.size(); + prompt.insert(prompt.end(), prompt_past0.end() - n_take0, prompt_past0.end()); + } + + // Fill remaining budget with dynamic tokens (rolling context) + const int n_take1 = std::min(max_prompt_ctx - n_take0 - 1, prompt_past1.size()); + prompt.insert(prompt.end(), prompt_past1.end() - n_take1, prompt_past1.end()); + } } // init new transcription with sot, language (opt) and task tokens @@ -7543,14 +7581,17 @@ int whisper_full_with_state( //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); - // update prompt_past - prompt_past.clear(); - if (prompt.front() == whisper_token_prev(ctx)) { - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + // update prompt_past1 + prompt_past1.clear(); + if (!params.carry_initial_prompt && !prompt.empty() && prompt.front() == whisper_token_prev(ctx)) { + prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); } - for (int i = 0; i < result_len && !is_no_speech; ++i) { - prompt_past.push_back(tokens_cur[i].id); + // Add newly decoded tokens to the rolling context + if (!is_no_speech) { + for (int i = 0; i < result_len; ++i) { + prompt_past1.push_back(tokens_cur[i].id); + } } if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) { From c3b5c4d9349f4353dd8620fa621a4386ab8812a6 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 11 Oct 2025 16:55:16 +0200 Subject: [PATCH 013/104] whisper : Support using devices of type iGPU (#3469) --- src/whisper.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 18874309..33e556c4 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -1296,7 +1296,7 @@ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & pa if (params.use_gpu) { for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) { + if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_IGPU) { if (cnt == params.gpu_device) { dev = dev_cur; } @@ -1365,7 +1365,7 @@ static buft_list_t make_buft_list(whisper_context_params & params) { int cnt = 0; for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { if (cnt == params.gpu_device) { auto * buft = ggml_backend_dev_buffer_type(dev); if (buft) { @@ -1403,6 +1403,7 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * bool op_supported = true; if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { // GPU and default CPU backend support all operators op_supported = true; @@ -4455,6 +4456,7 @@ static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tens bool op_supported = true; if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { // GPU and default CPU backend support all operators op_supported = true; From 199626d79e4265b652b51c2dd3ca8a480d01d684 Mon Sep 17 00:00:00 2001 From: lhez Date: Tue, 30 Sep 2025 09:55:13 -0700 Subject: [PATCH 014/104] opencl: support ne3 in get_rows (llama/15866) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 39 +++++++++++-------- ggml/src/ggml-opencl/kernels/get_rows.cl | 48 ++++++++++++++++++------ 2 files changed, 59 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0cf3b924..a9405ab0 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4222,15 +4222,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const int ne00 = src0 ? src0->ne[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const int ne10 = src1 ? src1->ne[0] : 0; - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb1 = dst ? dst->nb[1] : 0; - const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const int ne00 = src0->ne[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + const int ne10 = src1->ne[0]; + const cl_ulong nb10 = src1->nb[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -4267,14 +4271,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; - size_t local_work_size[] = {1, 1, 1}; + size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12}; + size_t local_work_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index b3fea292..c2962edc 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { float16 temp; + if (ind >= ne00) { + return; + } dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp; } } From 8208cea829902016dae08193022060f39823aae1 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 30 Sep 2025 09:57:51 -0700 Subject: [PATCH 015/104] ggml webgpu: support for rope,div,sub,glu,scale,cont operators (llama/16187) * Work on rope * Simplify inplace operation generation and combine mul/add generation * Work on rope variants * implement neox rope * rope complete * Add sub,div,glu operators * implement scale op * Update cpy shader to handle cont/more types * formatting * Update test vars printing for rope,rms_norm * Avoid ROPE hardcoded constants * Add TODO to change ROPE constants to enum Co-authored-by: Georgi Gerganov * fix TODO comment --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 2 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 488 +++++++++++++++--- .../ggml-webgpu/wgsl-shaders/add.tmpl.wgsl | 44 -- .../wgsl-shaders/add_in_place.tmpl.wgsl | 41 -- .../ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl | 188 +++++++ .../ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | 101 ++++ ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 60 --- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 17 +- .../wgsl-shaders/get_rows.tmpl.wgsl | 2 +- .../ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl | 323 ++++++++++++ .../ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl | 44 -- .../wgsl-shaders/mul_in_place.tmpl.wgsl | 41 -- .../ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 57 +- .../wgsl-shaders/rms_norm_in_place.wgsl | 48 -- .../ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl | 282 ++++++++++ .../ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl | 90 ++++ 16 files changed, 1461 insertions(+), 367 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 36b23dc6..5028a9ce 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -237,6 +237,8 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 +#define GGML_ROPE_TYPE_NORMAL 0 #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cee4b083..93200a4d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -130,13 +130,15 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline get_rows_pipeline[30]; wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline; - wgpu::ComputePipeline add_pipeline[2]; - wgpu::ComputePipeline add_ip_pipeline[2]; - wgpu::ComputePipeline mul_pipeline[2]; - wgpu::ComputePipeline mul_ip_pipeline[2]; - wgpu::ComputePipeline rms_norm_pipeline; - wgpu::ComputePipeline rms_norm_ip_pipeline; + wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type + wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace + wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace + wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split + wgpu::ComputePipeline scale_pipeline[2]; // inplace size_t memset_bytes_per_thread; @@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shape — same for both tensors even if permuted - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3] + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; std::vector entries = { @@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, + ggml_op_name(dst->op)); } static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { @@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst, wgpu::ComputePipeline & pipeline, - bool in_place) { + bool inplace) { std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, src1), .size = ggml_webgpu_tensor_binding_size(ctx, src1) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), @@ -691,30 +695,23 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, } static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool in_place = ggml_webgpu_tensor_equal(src, dst); - - uint32_t eps; - memcpy(&eps, dst->op_params, sizeof(float)); + int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader }; - if (!in_place) { - params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type))); - if (!in_place) { - params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) src->ne[0]); - params.push_back((uint32_t) src->ne[1]); - params.push_back((uint32_t) src->ne[2]); - params.push_back((uint32_t) src->ne[3]); - params.push_back(eps); // epsilon, will be bitcast to float in shader std::vector entries = { { .binding = 0, @@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline; - if (in_place) { - pipeline = ctx->rms_norm_ip_pipeline; - } else { - pipeline = ctx->rms_norm_pipeline; - } size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + +static void ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_freq_factor = (src2 != nullptr); + + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + int sections[4]; + memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int)); + + float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(src0) / 2, + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) n_dims, + (uint32_t) mode, + *(uint32_t *) &theta_scale, + *(uint32_t *) &attn_factor, + *(uint32_t *) &freq_scale, + *(uint32_t *) &ext_factor, + *(uint32_t *) &corr_dims[0], + *(uint32_t *) &corr_dims[1], + (uint32_t) sections[0], + (uint32_t) sections[1], + (uint32_t) sections[2], + (uint32_t) sections[3] + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + }; + uint32_t dst_binding = 2; + if (has_freq_factor) { + dst_binding = 3; + entries.push_back({ .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + } + if (!inplace) { + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } +static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const int split = (src1 != nullptr); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai + *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + }; + uint32_t dst_binding = 1; + if (split) { + dst_binding = 2; + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + } + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + +static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + int inplace = ggml_webgpu_tensor_equal(src, dst); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &dst->op_params[1] // bias + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) } + }; + if (!inplace) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; switch (node->op) { // no-ops @@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_RESHAPE: return false; case GGML_OP_CPY: + case GGML_OP_CONT: ggml_webgpu_cpy(ctx, src0, node); break; case GGML_OP_SET_ROWS: @@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_webgpu_mul_mat(ctx, src0, src1, node); break; case GGML_OP_ADD: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_SUB: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_MUL: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_DIV: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_RMS_NORM: ggml_webgpu_rms_norm(ctx, src0, node); break; + case GGML_OP_ROPE: + ggml_webgpu_rope(ctx, src0, src1, src2, node); + break; + case GGML_OP_GLU: + ggml_webgpu_glu(ctx, src0, src1, node); + break; + case GGML_OP_SCALE: + ggml_webgpu_scale(ctx, src0, node); + break; default: return false; } @@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", - ggml_webgpu_max_wg_size_entry(webgpu_ctx)); + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], + wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], + wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], + wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], + wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32, - "add_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16, - "add_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, + "add_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, + "add_f16_inplace", constants); +} + +static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, + "sub_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, + "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32, - "mul_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16, - "mul_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, + "mul_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, + "mul_f16_inplace", constants); +} + +static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, + "div_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, + "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place, - "rms_norm_in_place", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, + "rms_norm_inplace", constants); +} + +static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, + "rope_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], + wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff, + "rope_f32_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1], + wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16, + "rope_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1], + wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff, + "rope_f16_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1], + wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); +} + +static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + // reglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], + wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0], + wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1], + wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1], + wgsl_reglu_f16_split, "reglu_f16_split", constants); + // geglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0], + wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0], + wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1], + wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1], + wgsl_geglu_f16_split, "geglu_f16_split", constants); + // swiglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0], + wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0], + wgsl_swiglu_f16, "swiglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1], + wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1], + wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + // swiglu_oai + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0], + wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1], + wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + // geglu_erf + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0], + wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0], + wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1], + wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1], + wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + // geglu_quick + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0], + wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0], + wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1], + wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1], + wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); +} + +static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; + // on smaller devices (or CI), tensors may be larger than the max storage buffer size if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || @@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = true; break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) && - (op->src[1]->type == op->type); + case GGML_OP_DIV: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && + (src1->type == op->type); break; case GGML_OP_CPY: + case GGML_OP_CONT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SET_ROWS: supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64); break; case GGML_OP_GET_ROWS: - if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || - op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) { + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || + ggml_webgpu_supported_qtype(src0->type)) { supports_op = (op->type == GGML_TYPE_F32); } break; case GGML_OP_MUL_MAT: { - switch (op->src[1]->type) { + switch (src1->type) { case GGML_TYPE_F16: - supports_op = (op->src[0]->type == GGML_TYPE_F16); + supports_op |= (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: - switch (op->src[0]->type) { + switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: - supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ROPE: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_GLU_OP_SWIGLU_OAI: + supports_op = op->type == GGML_TYPE_F32; + break; + default: + break; + } + break; + case GGML_OP_SCALE: + supports_op = op->type == GGML_TYPE_F32; break; default: break; @@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); ggml_webgpu_init_add_pipeline(ctx); + ggml_webgpu_init_sub_pipeline(ctx); ggml_webgpu_init_mul_pipeline(ctx); + ggml_webgpu_init_div_pipeline(ctx); ggml_webgpu_init_rms_norm_pipeline(ctx); + ggml_webgpu_init_rope_pipeline(ctx); + ggml_webgpu_init_glu_pipeline(ctx); + ggml_webgpu_init_scale_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl deleted file mode 100644 index f261cbb5..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl deleted file mode 100644 index 903f7bdb..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl new file mode 100644 index 00000000..1ce4d83f --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl @@ -0,0 +1,188 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "add_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "add_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl new file mode 100644 index 00000000..db1aa349 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl @@ -0,0 +1,101 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f32" + } + }, + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f32" + } + } +] + +#end(VARIANTS) + +#define(SHADER) +enable f16; + +@group(0) @binding(0) +var src: array<{{SRC_TYPE}}>; + +@group(0) @binding(1) +var dst: array<{{DST_TYPE}}>; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); +} +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl deleted file mode 100644 index 6fe924c5..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ /dev/null @@ -1,60 +0,0 @@ -enable f16; - -@group(0) @binding(0) -var src: array; - -@group(0) @binding(1) -var dst: array; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shape (same for both tensors) - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 + - i2 * params.stride_dst2 + i3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]); -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d9dfd7d6..251051ea 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile): raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" - shader_variant = replace_placeholders(shader_template, variant["REPLS"]) - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) + if "REPLS" in variant: + final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) - if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: + if "SHADER_NAME" in variant: + output_name = variant["SHADER_NAME"] + elif "SHADER_SUFFIX" in variant: + output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] + elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "TYPE_SUFFIX" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"] - elif "TYPE" in variant["REPLS"]: + elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) + elif "REPLS" in variant and "TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl index e3fe311b..f80ce1fc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl @@ -2,9 +2,9 @@ [ { + "SHADER_SUFFIX": "f32_vec", "REPLS": { "TYPE" : "vec4", - "TYPE_SUFFIX": "f32_vec", "DST_TYPE": "vec4", "BLOCK_SIZE": 4 }, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl new file mode 100644 index 00000000..03fcd548 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl @@ -0,0 +1,323 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "reglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "geglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "swiglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_oai_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "swiglu_oai_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "geglu_erf_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_quick_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(REGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return max(a, 0) * b; +} +#enddecl(REGLU) + +#decl(GEGLU) +const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; +const GELU_COEF_A: {{TYPE}} = 0.044715; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; +} +#enddecl(GEGLU) + +#decl(SWIGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a / (1.0 + exp(-a)) * b; +} +#enddecl(SWIGLU) + +#decl(SWIGLU_OAI) +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#enddecl(SWIGLU_OAI) + +#decl(GEGLU_ERF) +const p_erf: {{TYPE}} = 0.3275911; +const a1_erf: {{TYPE}} = 0.254829592; +const a2_erf: {{TYPE}} = -0.284496736; +const a3_erf: {{TYPE}} = 1.421413741; +const a4_erf: {{TYPE}} = -1.453152027; +const a5_erf: {{TYPE}} = 1.061405429; +const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#enddecl(GEGLU_ERF) + +#decl(GEGLU_QUICK) +const GELU_QUICK_COEF: {{TYPE}} = -1.702; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#enddecl(GEGLU_QUICK) + +#decl(NO_SPLIT) +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} +#enddecl(NO_SPLIT) + +#decl(SPLIT) +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + return src0[base]; +} + +fn b_value(base: u32) -> {{TYPE}} { + return src1[base]; +} +#enddecl(SPLIT) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl deleted file mode 100644 index 12506e14..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl deleted file mode 100644 index e467e59e..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index f919a513..a275eeb9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -1,9 +1,48 @@ -@group(0) @binding(0) -var src: array; +#define(VARIANTS) + +[ + { + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_SUFFIX": "inplace", + "DECLS": ["INPLACE"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + dst[dst_offset] = scale * src[src_offset]; +} @group(0) @binding(1) var dst: array; +@group(0) @binding(2) +var params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + +#define(SHADER) + struct Params { offset_src: u32, // in elements offset_dst: u32, // in elements @@ -23,11 +62,13 @@ struct Params { ne2: u32, ne3: u32, - eps: u32 + eps: f32 }; -@group(0) @binding(2) -var params: Params; +@group(0) @binding(0) +var src: array; + +DECLS override wg_size: u32; @compute @workgroup_size(wg_size) @@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var j: u32 = 0; j < params.ne0; j++) { sum += src[i_src_row + j] * src[i_src_row + j]; } - let eps = bitcast(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); for (var j: u32 = 0; j < params.ne0; j++) { - dst[i_dst_row + j] = scale * src[i_src_row + j]; + update(i_src_row + j, i_dst_row + j, scale); } } +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl deleted file mode 100644 index ae84f556..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +++ /dev/null @@ -1,48 +0,0 @@ -@group(0) @binding(0) -var a: array; - -struct Params { - offset: u32, // in elements - - // Strides (in elements) - stride1: u32, - stride2: u32, - stride3: u32, - - // Shape - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: u32 -}; - -@group(0) @binding(1) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne1 * params.ne2 * params.ne3) { - return; - } - - // one thread per row - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1; - - var sum = 0.0f; - for (var j: u32 = 0; j < params.ne0; j++) { - sum += a[i_row + j] * a[i_row + j]; - } - let eps = bitcast(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); - for (var j: u32 = 0; j < params.ne0; j++) { - a[i_row + j] = scale * a[i_row + j]; - } -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl new file mode 100644 index 00000000..9a6ff411 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl @@ -0,0 +1,282 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f32_ff", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_ff_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f16_ff", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_ff_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(ROTATE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = {{TYPE}}(out0); + dst[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE) + +#decl(ROTATE_INPLACE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = {{TYPE}}(out0); + src0[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE_INPLACE) + +#decl(NO_FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#enddecl(NO_FF_FUNC) + +#decl(FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} +#enddecl(FF_FUNC) + +#decl(NO_FF_BINDINGS) + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +#enddecl(NO_FF_BINDINGS) + +#decl(NO_FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var params: Params; + +#enddecl(NO_FF_BINDINGS_INPLACE) + +#decl(FF_BINDINGS) + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array<{{TYPE}}>; + +@group(0) @binding(4) +var params: Params; + +#enddecl(FF_BINDINGS) + +#decl(FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#enddecl(FF_BINDINGS_INPLACE) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_src2: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + n_threads: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + n_dims: u32, + mode: u32, + theta_scale: f32, + attn_factor: f32, + freq_scale: f32, + ext_factor: f32, + corr_dim0: f32, + corr_dim1: f32, + sections0: u32, + sections1: u32, + sections2: u32, + sections3: u32 +}; + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array; + +DECLS + +fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { + let y = (f32(i / 2) - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// returns vector of (cos_theta, sin_theta) +// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row +fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { + var mscale = params.attn_factor; + var theta = params.freq_scale * theta_extrap; + if (params.ext_factor != 0.0f) { + let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; + theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); + } + return vec2(cos(theta) * mscale, sin(theta) * mscale); +} + +fn pair_base(i0: u32, div_2: bool) -> u32 { + if (div_2) { + return i0 / 2; + } else { + return i0; + } +} + +fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { + if (is_vision) { + return params.n_dims; + } else if (is_neox || is_mrope) { + return params.n_dims / 2; + } else { + return 1; + } +} + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + // two elements per thread + if (gid.x >= params.n_threads) { + return; + } + + let is_neox = bool(params.mode & 2); + let is_mrope = bool(params.mode & 8); + let is_vision = params.mode == 24; + + var i = gid.x * 2; // start index for this thread + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + if (i0 >= params.n_dims && !is_vision) { + let i_src = i_src_row + i0; + let i_dst = i_dst_row + i0; + rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); + return; + } + + var theta_base_mult: u32 = 0; + var theta_scale_pwr: u32 = i0 / 2; + if (is_mrope) { + let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; + let sec_w = params.sections1 + params.sections0; + let sec_e = params.sections2 + sec_w; + let sector = (i0 / 2) % sect_dims; + if (sector >= params.sections0 && sector < sec_w) { + theta_base_mult = 1; + if (is_vision) { + theta_scale_pwr = sector - params.sections0; + } + } else if (sector >= sec_w && sector < sec_e) { + theta_base_mult = 2; + if (is_vision) { + theta_scale_pwr = sector - sec_w; + } + } else if (sector >= sec_e) { + if (is_vision) { + theta_scale_pwr = sector - sec_e; + theta_scale_pwr = (i0 / 2) % sec_e; + } + theta_base_mult = 3; + } else if (is_vision) { + theta_scale_pwr = sector; + } + } + let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); + let thetas = rope_yarn(theta_base/freq_factor(i0), i0); + + let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); + let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); + + let x0 = f32(src0[i_src]); + let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); + rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl new file mode 100644 index 00000000..040e80df --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl @@ -0,0 +1,90 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "scale_f32", + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "scale_f32_inplace", + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn store_scale(val: f32, offset: u32) { + dst[offset] = val; +} +#enddecl(NOT_INPLACE) + +#decl(INPLACE) +@group(0) @binding(1) +var params: Params; + +fn store_scale(val: f32, offset: u32) { + src[offset] = val; +} +#enddecl(INPLACE) + +#end(DECLS) + +#define(SHADER) + +struct Params { + offset_src: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + scale: f32, + bias: f32 +}; + +@group(0) @binding(0) +var src: array; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + store_scale(src[i_src] * params.scale + params.bias, i_dst); +} +#end(SHADER) From 31bb8699295bd6dde32497d52a282dc5ae69e017 Mon Sep 17 00:00:00 2001 From: lhez Date: Tue, 30 Sep 2025 10:45:45 -0700 Subject: [PATCH 016/104] opencl: support pad_ext (llama/15888) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 67 +++++++++++++++++++++------- ggml/src/ggml-opencl/kernels/pad.cl | 49 +++++++++++--------- 2 files changed, 80 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a9405ab0..79d21487 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1 && - (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -5881,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t GGML_ASSERT(dst->extra); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5899,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t const int s_ne0 = src0->ne[0]; const int s_ne1 = src0->ne[1]; const int s_ne2 = src0->ne[2]; + const int s_ne3 = src0->ne[3]; + + const int s_nb0 = src0->nb[0]; + const int s_nb1 = src0->nb[1]; + const int s_nb2 = src0->nb[2]; + const int s_nb3 = src0->nb[3]; const int d_ne0 = dst->ne[0]; const int d_ne1 = dst->ne[1]; const int d_ne2 = dst->ne[2]; + const int d_ne3 = dst->ne[3]; + + const int d_nb0 = dst->nb[0]; + const int d_nb1 = dst->nb[1]; + const int d_nb2 = dst->nb[2]; + const int d_nb3 = dst->nb[3]; + + const int lp0 = ((const int*)(dst->op_params))[0]; + const int rp0 = ((const int*)(dst->op_params))[1]; + const int lp1 = ((const int*)(dst->op_params))[2]; + const int rp1 = ((const int*)(dst->op_params))[3]; + const int lp2 = ((const int*)(dst->op_params))[4]; + const int rp2 = ((const int*)(dst->op_params))[5]; + const int lp3 = ((const int*)(dst->op_params))[6]; + const int rp3 = ((const int*)(dst->op_params))[7]; cl_kernel kernel = backend_ctx->kernel_pad; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); size_t lws0 = 64; size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; - size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl index 747fa7fe..31fb7ccd 100644 --- a/ggml/src/ggml-opencl/kernels/pad.cl +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -1,30 +1,39 @@ kernel void kernel_pad( - global const void * src0_ptr, - ulong src0_offset, - global void * dst_ptr, - ulong dst_offset, - int s_ne0, int s_ne1, int s_ne2, - int d_ne0, int d_ne1, int d_ne2 + global void * src0, + ulong offset0, + global void * dst, + ulong offsetd, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne0, int ne1, int ne2, int ne3, + ulong nb0, ulong nb1, ulong nb2, ulong nb3, + int lp0, int rp0, + int lp1, int rp1, + int lp2, int rp2, + int lp3, int rp3 ) { - global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); - global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - int nidx = get_global_id(0); - int idx_d1 = get_group_id(1); - int idx_d2 = get_group_id(2); + int i0 = get_global_id(0); + int i1 = get_group_id(1); + int i2 = get_group_id(2) % ne2; + int i3 = get_group_id(2) / ne2; - if (nidx >= d_ne0) { + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + global float * src0_ptr = (global float *)((global char *)src0 + src0_idx); + global float * dst_ptr = (global float *)((global char *)dst + dst_idx); - if (in_src_bounds) { - int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; - dst[dst_el_offset] = src0[src_el_offset]; - } else { - dst[dst_el_offset] = 0.0f; - } + bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3); + + *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f; } From b0560310aa6549cc94ae94a29212c687af8f2ca0 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 1 Oct 2025 07:56:36 +0000 Subject: [PATCH 017/104] vulkan: make ggml_vk_default_dispatcher support older vulkan headers (llama/16345) * make ggml_vk_default_dispatcher support older vulkan headers * simpilfy with using --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2608cbd0..003a9010 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -9,8 +9,14 @@ #define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 // We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE // to avoid conflicts with applications or other libraries who might use it. +#if VK_HEADER_VERSION >= 301 namespace vk::detail { class DispatchLoaderDynamic; } -vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +using vk::detail::DispatchLoaderDynamic; +#else +namespace vk { class DispatchLoaderDynamic; } +using vk::DispatchLoaderDynamic; +#endif +DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include @@ -4538,9 +4544,8 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); -static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; - -vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() { +static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; +DispatchLoaderDynamic & ggml_vk_default_dispatcher() { return ggml_vk_default_dispatcher_instance; } From b73f67d3f6c0de80285f0114079261e9b8aaafec Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 1 Oct 2025 23:09:25 +0200 Subject: [PATCH 018/104] HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.0 (llama/16221) * HIP: Disable ROCWMMA fatt on CDNA when compiled against ROCWMMA 2.0.0 rocwmma 2.0.0 includes a bug in the code fakeing fp16 accumulation on CDNA * CUDA: Fix volta condition in ggml_cuda_should_use_wmma_fattn --- ggml/CMakeLists.txt | 1 - ggml/src/ggml-cuda/common.cuh | 29 ----------------- ggml/src/ggml-cuda/fattn-tile.cu | 5 +-- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 12 +++---- ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 46 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/fattn.cu | 4 +-- ggml/src/ggml-cuda/vendors/hip.h | 4 +++ ggml/src/ggml-hip/CMakeLists.txt | 10 ------ 8 files changed, 61 insertions(+), 50 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 56420587..6ce52ffc 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -209,7 +209,6 @@ option(GGML_HIP "ggml: use HIP" option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) -option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF) option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c4246b65..d51abbea 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -220,14 +220,6 @@ static const char * cu_get_error_str(CUresult err) { #define FAST_FP16_AVAILABLE #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) -#define FP16_MMA_AVAILABLE -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) - -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) - #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) @@ -262,27 +254,6 @@ static bool fast_fp16_hardware_available(const int cc) { (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); } -// Any FP16 tensor core instructions are available for ggml code. -static bool fp16_mma_available(const int cc) { -#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) - return false; -#else - if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || - GGML_CUDA_CC_IS_MTHREADS(cc)) { - return true; - } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { -#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - return true; -#else - return false; -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - } else { - return false; - } -#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) -} - // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 131a5099..68de623d 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile.cuh" +#include "fattn-wmma-f16.cuh" // kq_stride == number of KQ rows to process per iteration // kq_nbatch == number of K columns to load in parallel for KQ calculation @@ -190,10 +191,10 @@ static __global__ void flash_attn_tile( #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE +#ifdef GGML_USE_WMMA_FATTN NO_DEVICE_CODE; return; -#endif // FP16_MMA_AVAILABLE +#endif // GGML_USE_WMMA_FATTN if (use_logit_softcap && !(D == 128 || D == 256)) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 2219191f..6c90d6d5 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -6,19 +6,19 @@ #include "fattn-common.cuh" #include "fattn-wmma-f16.cuh" -#ifdef FP16_MMA_AVAILABLE +#ifdef GGML_USE_WMMA_FATTN #if !defined(GGML_USE_HIP) #include -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) namespace wmma = mtmusa::wmma; #else // GGML_USE_MUSA namespace wmma = nvcuda::wmma; #endif // GGML_USE_MUSA -#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#elif defined(GGML_USE_HIP) #include namespace wmma = rocwmma; #endif // !defined(GGML_USE_HIP) -#endif // FP16_MMA_AVAILABLE +#endif // GGML_USE_WMMA_FATTN // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template @@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index beeea95e..1848d088 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,49 @@ #include "common.cuh" +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define GGML_USE_WMMA_FATTN +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) +#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#define GGML_USE_WMMA_FATTN +#elif defined(CDNA) +#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance" +#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#if defined(RDNA3) +#define GGML_USE_WMMA_FATTN +#endif // defined(RDNA3) +#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#define GGML_USE_WMMA_FATTN +#elif defined(RDNA4) +#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + +// WMMA flash attention requires FP16 matrix instructions to be available for ggml code. +static bool ggml_cuda_should_use_wmma_fattn(const int cc) { +#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + return true; + } else if (GGML_CUDA_CC_IS_CDNA(cc)){ +#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + } else { + return false; + } +#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) +} + void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1cbd4f5b..d7736d36 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -222,7 +222,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } - if (!fp16_mma_available(cc) && !turing_mma_available(cc)) { + if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) { return BEST_FATTN_KERNEL_NONE; } break; @@ -300,7 +300,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For large batch sizes, use the WMMA kernel if possible: - if (fp16_mma_available(cc)) { + if (ggml_cuda_should_use_wmma_fattn(cc)) { return BEST_FATTN_KERNEL_WMMA_F16; } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 37386afc..890c1036 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -6,6 +6,10 @@ #include #include +#if defined(GGML_HIP_ROCWMMA_FATTN) +#include +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index d327b90c..0e2b1847 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -39,12 +39,6 @@ endif() find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) -if (GGML_HIP_ROCWMMA_FATTN) - CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) - if (NOT ${FOUND_ROCWMMA}) - message(FATAL_ERROR "rocwmma has not been found") - endif() -endif() if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") @@ -117,10 +111,6 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() -if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0) - add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) -endif() - if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() From e29508be8b402dfd5450245a669aa2c717a97239 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Thu, 2 Oct 2025 21:29:56 +0800 Subject: [PATCH 019/104] musa: update compile flags (llama/16265) Signed-off-by: Xiaodong Ye --- ggml/src/ggml-cuda/fattn-vec.cuh | 2 -- ggml/src/ggml-cuda/topk-moe.cu | 4 +--- ggml/src/ggml-musa/CMakeLists.txt | 2 +- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 59c62553..89ab0f16 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -535,8 +535,6 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 039f2847..afe4aee2 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -13,7 +13,7 @@ It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template +template __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, float * weights, int32_t * ids, @@ -204,8 +204,6 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); - cudaStream_t stream = ctx.stream(); - const int n_expert_used = weights->ne[1]; if (with_norm) { diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index cdb3818c..f8477a2e 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -56,7 +56,7 @@ if (MUSAToolkit_FOUND) set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) foreach(SOURCE ${GGML_SOURCES_MUSA}) - set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu") + set(COMPILE_FLAGS "-Od3 -fno-strict-aliasing -ffast-math -fsigned-char -x musa -mtgpu -fmusa-flush-denormals-to-zero") foreach(ARCH ${MUSA_ARCHITECTURES}) set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}") endforeach() From 33ca8355c43b3f4f46d207f32908588a33e54724 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 2 Oct 2025 19:43:22 +0200 Subject: [PATCH 020/104] model : Apertus model implementation (llama/15852) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First attempt * No permute during convert (fixes qk tensors), proper norm application. * RoPE = NeoX * Coherence! * Migrate xielu params from tensors to hyperparameters * Simple CUDA kernel * Revert stupid LLM refactorings * Chat template support * configchecker / flake8 errors * Reorder unary.cu * I do conclude that LLMs are, in fact, stupid. * Fix after merge * Final newline * Make xIELU an UNARY_OP * Final newline * Correctly account for parameter shift * Argh. * Update ggml/src/ggml-cpu/unary-ops.cpp Co-authored-by: Georgi Gerganov * Refactor: remove unused methods, inline and factorize softplus, add const modifiers * Revert CUDA changes, implement xIELU as a separate OP * Pesky newline * Add float2half / half2float for F16 inputs/outputs * CUDA variants, attempt 2 * Actually, attempt 3 * Update ggml/src/ggml-cuda/unary.cu Co-authored-by: Johannes Gäßler * Missing convert header * Proper formula and reference for xIELU in the comments. * Modify unary-ops.cpp to add the functor-based logic besides the template system to retain optimizations * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret * Add tensor mappings for Apertus to global list instead * Fix lazy on scalars * Update ggml/src/ggml-cuda/unary.cu Co-authored-by: Johannes Gäßler * Add comment about the constraints on positive/negative alpha * Change `softplus` to `ggml_softplus` --------- Co-authored-by: Georgi Gerganov Co-authored-by: Johannes Gäßler Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml.h | 13 ++++ ggml/src/ggml-cpu/ggml-cpu.c | 1 + ggml/src/ggml-cpu/ops.cpp | 8 ++- ggml/src/ggml-cpu/unary-ops.cpp | 103 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/unary-ops.h | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 3 + ggml/src/ggml-cuda/unary.cu | 54 +++++++++++++++++ ggml/src/ggml-cuda/unary.cuh | 3 + ggml/src/ggml-impl.h | 3 + ggml/src/ggml.c | 27 ++++++++- 10 files changed, 212 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5028a9ce..f65eb75e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -576,6 +576,7 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_XIELU, GGML_UNARY_OP_COUNT, }; @@ -1150,6 +1151,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // xIELU activation function + // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) + // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions + // that constrain the positive and negative source alpha values respectively + GGML_API struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps); + // gated linear unit ops // A: n columns, r rows, // result is n / 2 columns, r rows, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index dbc07301..eded6eb7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_XIELU: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 14f7dcf4..6275c830 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim @@ -8997,6 +8997,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_XIELU: + { + ggml_compute_forward_xielu(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 4fce569b..cf1a4615 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -52,6 +52,15 @@ static inline float op_sqrt(float x) { return sqrtf(x); } +static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) { + if (x > 0.0f) { + return alpha_p * x * x + beta * x; + } else { + const float min_x_eps = fminf(x, eps); + return (expm1f(min_x_eps) - x) * alpha_n + beta * x; + } +} + static inline float op_sin(float x) { return sinf(x); } @@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { } } +template +static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + +// Extend vec_unary_op to support functors +template +static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) { + constexpr auto src0_to_f32 = type_conversion_table::to_f32; + constexpr auto f32_to_dst = type_conversion_table::from_f32; + + for (int i = 0; i < n; i++) { + y[i] = f32_to_dst(op(src0_to_f32(x[i]))); + } +} + +// Extend apply_unary_op to support functors +template +static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(dst_t)); + GGML_ASSERT(nb00 == sizeof(src0_t)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op); + } +} + +// Generic dispatcher for functors +template +static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } @@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } + +void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { + return op_xielu(f, alpha_n, alpha_p, beta, eps); + }; + + unary_op_functor(params, dst, xielu_op_params); +} + diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index b1ade2c8..697c1e0d 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b7e81b21..26e72bbc 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2334,6 +2334,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; default: return false; } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 5aff8a87..3c564566 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,4 +1,5 @@ #include "unary.cuh" +#include "convert.cuh" static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); @@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } +/* CUDA kernel + launcher for xIELU */ + +template +static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xi = ggml_cuda_cast(x[i]); + + const float gate_pos = (xi > 0.0f); + const float y_pos = alpha_p * xi * xi + beta * xi; + const float min_v_eps = fminf(xi, eps); + const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; + const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + + dst[i] = ggml_cuda_cast(out); +} + +template +static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) { + const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE; + xielu_kernel<<>>(x, dst, k, alpha_n, alpha_p, beta, eps); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + if (src0->type == GGML_TYPE_F16) { + xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } else { + xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } +} + + + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index da3caf1d..8e7644fc 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -16,6 +16,7 @@ #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 #define CUDA_GLU_BLOCK_SIZE 256 +#define CUDA_XIELU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 86a1ebf6..d0fb3bcc 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline float ggml_softplus(float input) { + return (input > 20.0f) ? input : logf(1 + expf(input)); +} // // logging // diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index aecbdad5..7d50b42a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", + "XIELU", }; -static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); - +static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_xielu + +struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); + ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 3, beta); + ggml_set_op_params_f32(result, 4, eps); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( From 27ebde6afdf596768b12756a02a4e9d35f9c5cb0 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 2 Oct 2025 11:00:31 -0700 Subject: [PATCH 021/104] ggml webgpu: add support for soft_max, optimize rms_norm (llama/16357) * Add inplace softmax * Move rms_norm to split row approach * Update debug for supports_op * clean up debug statements * Update tests/test-backend-ops.cpp Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 7 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 193 ++++++++-- .../ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 43 ++- .../wgsl-shaders/soft_max.tmpl.wgsl | 344 ++++++++++++++++++ ggml/src/ggml.c | 9 + 5 files changed, 552 insertions(+), 44 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f65eb75e..60c6b63d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1630,6 +1630,13 @@ extern "C" { float scale, float max_bias); + GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias); + GGML_API void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 93200a4d..de68c568 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -28,6 +28,7 @@ /* Constants */ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 +#define WEBGPU_WAIT_ANY_BATCH_SIZE 64 #define WEBGPU_MUL_MAT_WG_SIZE 64 #define WEBGPU_NUM_PARAM_BUFS 100 #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters @@ -35,6 +36,9 @@ #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +// For operations which process a row in parallel, this seems like a reasonable default +#define WEBGPU_ROW_SPLIT_WG_SIZE 64 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. @@ -130,15 +134,16 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline get_rows_pipeline[30]; wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type - wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace - wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace - wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split - wgpu::ComputePipeline scale_pipeline[2]; // inplace + wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type + wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace + wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace + wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split + wgpu::ComputePipeline scale_pipeline[2]; // inplace + wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace size_t memset_bytes_per_thread; @@ -256,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { }), UINT64_MAX); } else { - // existing callbacks, wait on them - ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX); + // WebGPU implementations may limit the number of futures that can be waited on at once, + // so wait in batches (64 is what Dawn supports). + for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) { + size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size()); + ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX); + } ctx->callback_futures.clear(); } } @@ -726,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x, + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src), ggml_op_name(dst->op)); } @@ -912,6 +919,79 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens ggml_op_name(dst->op)); } +static void ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here + const int has_sink = (src2 != nullptr); + float max_bias; + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + mask_type < 2 ? (uint32_t) src1->ne[2] : 0, + mask_type < 2 ? (uint32_t) src1->ne[3] : 0, + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &max_bias, + *(uint32_t *) &n_head_log2, + *(uint32_t *) &m0, + *(uint32_t *) &m1 + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) } + }; + uint32_t binding_num = 1; + if (mask_type < 2) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + binding_num++; + } + if (has_sink) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + binding_num++; + } + if (!inplace) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + ggml_nrows(dst), ggml_op_name(dst->op)); +} + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -1237,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } -// The max workgroup size is a common constant -static std::vector ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) { +// Workgroup size is a common constant +static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { std::vector constants(1); constants[0].key = "wg_size"; - constants[0].value = webgpu_ctx->max_wg_size_x; + constants[0].value = wg_size; return constants; } @@ -1309,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", - ggml_webgpu_max_wg_size_entry(webgpu_ctx)); + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32, @@ -1363,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], wgsl_cpy_f32_f32, "cpy_f32_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], @@ -1375,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", @@ -1387,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", @@ -1399,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", @@ -1411,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", @@ -1423,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, @@ -1431,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, "rope_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], @@ -1451,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); // reglu ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], wgsl_reglu_f32, "reglu_f32", constants); @@ -1505,13 +1585,43 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, "scale_f32_inplace", constants); } +static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32, + "soft_max_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace, + "soft_max_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink, + "soft_max_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1], + wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32, + "soft_max_f32_mask_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1], + wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16, + "soft_max_f32_mask_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1], + wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0], + wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1], + wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0], + wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1], + wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", + constants); +} + static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); @@ -1593,6 +1703,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; + ggml_tensor * src2 = op->src[2]; // on smaller devices (or CI), tensors may be larger than the max storage buffer size if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || @@ -1623,7 +1734,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); break; case GGML_OP_SET_ROWS: - supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64); + supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64); break; case GGML_OP_GET_ROWS: if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || @@ -1698,13 +1809,25 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const default: break; } -#ifdef GGML_WEBGPU_DEBUG - if (!supports_op) { - WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) - << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") - << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); + if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || + (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || + (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || + (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + supports_op = false; + WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); + } + + if (!supports_op) { + WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: " + << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) + << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") + << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); + } else { + WEBGPU_LOG_DEBUG("ggml_webgpu op supported: " + << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) + << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") + << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); } -#endif return supports_op; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index a275eeb9..4f72bb1c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -71,14 +71,14 @@ var src: array; DECLS override wg_size: u32; +var scratch: array; + @compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne1 * params.ne2 * params.ne3) { - return; - } +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { // one thread per row - var i = gid.x; + var i = wid.x; let i3 = i / (params.ne2 * params.ne1); i = i % (params.ne2 * params.ne1); let i2 = i / params.ne1; @@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let elems = (params.ne0 + wg_size - 1) / wg_size; + var sum = 0.0f; - for (var j: u32 = 0; j < params.ne0; j++) { - sum += src[i_src_row + j] * src[i_src_row + j]; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + sum += pow(src[i_src_row + col], 2.0); + col += wg_size; } + + scratch[lid.x] = sum; + workgroupBarrier(); + var offset = wg_size / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); - for (var j: u32 = 0; j < params.ne0; j++) { - update(i_src_row + j, i_dst_row + j, scale); + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_src_row + col, i_dst_row + col, scale); + col += wg_size; } } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl new file mode 100644 index 00000000..64ab576c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl @@ -0,0 +1,344 @@ +#define(VARIANTS) +[ + { + "SHADER_NAME": "soft_max_f32", + "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_inplace", + "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_sink", + "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_sink_inplace", + "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32_inplace", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16_inplace", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32_sink", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16_sink", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] + } +] +#end(VARIANTS) + +#define(DECLS) + +#decl(BASE_BINDINGS) +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#enddecl(BASE_BINDINGS) + +#decl(BASE_BINDINGS_INPLACE) +@group(0) @binding(1) +var params: Params; +#enddecl(BASE_BINDINGS_INPLACE) + +#decl(SINK_BINDINGS) +@group(0) @binding(1) +var sinks: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#enddecl(SINK_BINDINGS) + +#decl(SINK_BINDINGS_INPLACE) +@group(0) @binding(1) +var sinks: array; + +@group(0) @binding(2) +var params: Params; +#enddecl(SINK_BINDINGS_INPLACE) + +#decl(MASK_BINDINGS) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#enddecl(MASK_BINDINGS) + +#decl(MASK_BINDINGS_INPLACE) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var params: Params; +#enddecl(MASK_BINDINGS_INPLACE) + +#decl(MASK_SINK_BINDINGS) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var sinks: array; + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; +#enddecl(MASK_SINK_BINDINGS) + +#decl(MASK_SINK_BINDINGS_INPLACE) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var sinks: array; + +@group(0) @binding(3) +var params: Params; +#enddecl(MASK_SINK_BINDINGS_INPLACE) + +#decl(NOT_INPLACE) +fn inter_value(i: u32) -> f32 { + return dst[i]; +} + +fn update(i: u32, val: f32) { + dst[i] = val; +} +#enddecl(NOT_INPLACE) + +#decl(INPLACE) +fn inter_value(i: u32) -> f32 { + return src[i]; +} + +fn update(i: u32, val: f32) { + src[i] = val; +} +#enddecl(INPLACE) + +#decl(NO_MASK) +fn mask_val(i: u32) -> f32 { + return 0.0; +} +#enddecl(NO_MASK) + +#decl(MASK) +fn mask_val(i: u32) -> f32 { + return f32(mask[i]); +} +#enddecl(MASK) + +#decl(NO_SINK) +fn lower_max_bound(i2: u32) -> f32 { + return -1e30; +} + +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val; +} +#enddecl(NO_SINK) + +#decl(SINK) +fn lower_max_bound(i2: u32) -> f32 { + return sinks[params.offset_sinks + i2]; +} + +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val + exp(sinks[params.offset_sinks + i2] - max_val); +} +#enddecl(SINK) + +#end(DECLS) + +#define(SHADER) +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_sinks: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of src0/dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + // shape of src1 + ne12: u32, + ne13: u32, + + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) +var src: array; + +DECLS + +const CACHE_SIZE: u32 = 16; + +override wg_size: u32; +var scratch: array; + +@compute @workgroup_size(wg_size) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let elems = (params.ne0 + wg_size - 1) / wg_size; + + let head = f32(i2); + let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); + + var cache: array; + + var max_val = lower_max_bound(i2); + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); + max_val = max(max_val, val); + if (col < CACHE_SIZE) { + cache[col] = val; + } + col += wg_size; + } + + scratch[lid.x] = max_val; + workgroupBarrier(); + var offset = wg_size / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); + } + offset = offset / 2; + workgroupBarrier(); + } + let row_max = scratch[0]; + + var sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), + cache[col], col < CACHE_SIZE); + let ex = exp(val - row_max); + sum += ex; + if (col < CACHE_SIZE) { + cache[col] = ex; + } else { + update(i_dst_row + col, ex); + } + col += wg_size; + } + + scratch[lid.x] = sum; + workgroupBarrier(); + offset = wg_size / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + let row_sum = add_sinks(scratch[0], i2, row_max); + + let sum_recip = 1.0 / row_sum; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + col += wg_size; + } +} +#end(SHADER) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7d50b42a..2bce1375 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3852,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias) { + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true); +} + void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks) { From fd11cd97abcee469e15f408395a54ede3f47bc27 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 3 Oct 2025 03:33:08 -0500 Subject: [PATCH 022/104] vulkan: in flash attention, bounds check against nem1 (don't rely on GGML_KQ_MASK_PAD) (llama/16316) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 --- .../vulkan-shaders/flash_attn.comp | 3 +- .../vulkan-shaders/flash_attn_cm1.comp | 4 ++- .../vulkan-shaders/flash_attn_cm2.comp | 28 +++++++++++++++---- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 003a9010..def8dc96 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2614,8 +2614,6 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; @@ -7457,8 +7455,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { aligned = false; } - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 43b906e5..e4247502 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -153,12 +153,13 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } else { masksh[c][r] = float(0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index ddb1246e..e76dbb4d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -201,11 +201,13 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index ab647e9b..a65553a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -154,15 +154,31 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - coopmat mv; + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopmat mv; - S += slopeMat*coopmat(mv); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } } // Clear padding elements to -inf, so they don't contribute to rowmax From 90bdcf2ef62b5f5bbd2d003e857392c51e496a5e Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 3 Oct 2025 04:52:46 -0500 Subject: [PATCH 023/104] vulkan: Fix FA coopmat1 invalid array indexing (llama/16365) When computing sinks, the cm1 shader was looping r from 0 to Br rather than to rows_per_thread. I must have copied this from the scalar path (where it is correct), and somehow it wasn't causing failures on current drivers. --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index e76dbb4d..0507df2d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -358,8 +358,8 @@ void main() { } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f; From 2e6888089f2a15f0220a71d939317f60d5a88af4 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 3 Oct 2025 05:50:46 -0500 Subject: [PATCH 024/104] vulkan: Replace uses of maxMemoryAllocationSize and VK_WHOLE_SIZE (llama/16354) * vulkan: Replace uses of maxMemoryAllocationSize and VK_WHOLE_SIZE Replace maxMemoryAllocationSize check with maxBufferSize when creating buffers. The maxMemoryAllocationSize limit is a "soft" limit and allocations can succeed beyond that limit. This allows > 4GB buffers to be allocated on some implementations (e.g. NVIDIA) and tensors this large can be used for im2col and mul_mat. For temporary buffers (prealloc_x/y/etc) check against maxStorageBufferRange. I'm not sure this check is ideal, but we always use these buffers as a single full size binding and the limit may be smaller than maxMemoryAllocationSize or maxBufferSize, so I think this is reasonable. Replace descriptor range uses of VK_WHOLE_SIZE with a manually computed range. The maxStorageBufferRange may be smaller than the maxBufferSize or maxMemoryAllocationSize (and the Vulkan spec warns about this in a note) and it's invalid usage if VK_WHOLE_SIZE computes a range larger than maxStorageBufferRange. With this change, it should be possible to generate videos using wan networks in stable-diffusion.cpp. * vulkan: Add env var GGML_VK_FORCE_MAX_BUFFER_SIZE and use stoull --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 194 +++++++++++++-------------- 1 file changed, 95 insertions(+), 99 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index def8dc96..3cd89c71 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -393,6 +393,7 @@ struct vk_device_struct { vk::PhysicalDeviceProperties properties; std::string name; uint64_t max_memory_allocation_size; + uint64_t max_buffer_size; uint64_t suballocation_block_size; bool fp16; bool bf16; @@ -1563,6 +1564,12 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_backend_vk_free(ggml_backend_t backend); +static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) { + const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset}, + VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); + return range; +} + // Wait for ctx->fence to be signaled. static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep @@ -2012,8 +2019,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); - if (size > device->max_memory_allocation_size) { - throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); + if (size > device->max_buffer_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); } vk_buffer buf = std::make_shared(); @@ -2159,8 +2166,8 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) { buf.reset(); } -static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { - return { buf, 0, VK_WHOLE_SIZE }; +static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) { + return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) }; } static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { @@ -3853,17 +3860,27 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { - device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); } else if (maintenance4_support) { device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); } else { device->max_memory_allocation_size = props3.maxMemoryAllocationSize; } + const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE"); + + if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) { + device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE); + } else if (maintenance4_support) { + device->max_buffer_size = props4.maxBufferSize; + } else { + device->max_buffer_size = device->max_memory_allocation_size; + } + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { - device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); + device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE); } else { // Limit batching of allocations to 1GB by default to avoid fragmentation issues device->suballocation_block_size = 1024*1024*1024; @@ -6148,9 +6165,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || - (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -6225,7 +6242,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (x_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); @@ -6237,7 +6254,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6248,7 +6265,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6270,14 +6287,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; } - // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange. - VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); - // compute ggml_vk_matmul( ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, - { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n @@ -6444,8 +6458,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; } if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -6510,7 +6524,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); @@ -6519,7 +6533,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6530,7 +6544,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6929,8 +6943,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -6997,7 +7011,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (x_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, @@ -7010,7 +7024,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -7143,8 +7157,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -7210,7 +7224,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); @@ -7219,7 +7233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -7494,7 +7508,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; - if (split_k_size > ctx->device->max_memory_allocation_size) { + if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } if (ctx->prealloc_size_split_k < split_k_size) { @@ -7616,12 +7630,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), }, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been @@ -7633,21 +7647,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { - vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, - vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), }, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); ctx->prealloc_split_k_need_sync = true; } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), }, pc, { workgroups_x, workgroups_y, workgroups_z }); } @@ -8356,18 +8370,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; - uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; - uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0; - uint64_t d_sz = ggml_type_size(dst->type) * ned; - vk_buffer d_D = dst_buf_ctx->dev_buffer; - // Workaround for tiny tensor inputs on ROPE - if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { - y_sz = VK_WHOLE_SIZE; - } - GGML_ASSERT(d_D != nullptr); uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; if(!src0_uma) { @@ -8392,26 +8396,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); - if (op_supports_incontiguous) { - x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); - y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; - z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; - d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); - - if (x_buf_offset + x_sz >= d_X->size) { - x_sz = VK_WHOLE_SIZE; - } - if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { - y_sz = VK_WHOLE_SIZE; - } - if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { - z_sz = VK_WHOLE_SIZE; - } - if (d_buf_offset + d_sz >= d_D->size) { - d_sz = VK_WHOLE_SIZE; - } - } - std::array elements; // Single call if dimension 2 is contiguous @@ -8602,19 +8586,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; } - if (!op_supports_incontiguous) { - if (x_sz != VK_WHOLE_SIZE) { - x_sz *= ne02 * ne03; + uint64_t x_sz, y_sz, z_sz, d_sz; + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset); } - if (use_src1 && y_sz != VK_WHOLE_SIZE) { - y_sz *= ne12 * ne13; + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset); } - if (use_src2 && z_sz != VK_WHOLE_SIZE) { - z_sz *= ne22 * ne23; + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); } - if (d_sz != VK_WHOLE_SIZE) { - d_sz *= ned2 * ned3; + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); } + } else { + x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; + y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; + z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; + d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; } if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { @@ -8624,7 +8620,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz }, - vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE }, + ggml_vk_subbuffer(ctx, d_A, a_buf_offset), }, pc, elements); } else if (op == GGML_OP_GLU) { // Empty src1 is possible in glu, but the shader needs a buffer @@ -8817,18 +8813,18 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, static_assert(MAX_PARAMETER_COUNT == 12); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE }, + ggml_vk_subbuffer(ctx, buf[0], offset[0]), + ggml_vk_subbuffer(ctx, buf[1], offset[1]), + ggml_vk_subbuffer(ctx, buf[2], offset[2]), + ggml_vk_subbuffer(ctx, buf[3], offset[3]), + ggml_vk_subbuffer(ctx, buf[4], offset[4]), + ggml_vk_subbuffer(ctx, buf[5], offset[5]), + ggml_vk_subbuffer(ctx, buf[6], offset[6]), + ggml_vk_subbuffer(ctx, buf[7], offset[7]), + ggml_vk_subbuffer(ctx, buf[8], offset[8]), + ggml_vk_subbuffer(ctx, buf[9], offset[9]), + ggml_vk_subbuffer(ctx, buf[10], offset[10]), + ggml_vk_subbuffer(ctx, buf[11], offset[11]), }, pc, elements); } @@ -10002,7 +9998,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { ggml_vk_matmul( - ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), + ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, split_k, batch, batch, batch, 1, 1, n @@ -10313,7 +10309,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // // vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); // ggml_vk_ctx_begin(ctx->device, subctx); -// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne); // ggml_vk_ctx_end(subctx); // // auto begin = std::chrono::high_resolution_clock::now(); From a70144a873686c5534c05d90912a12f266a92798 Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 3 Oct 2025 13:49:08 +0200 Subject: [PATCH 025/104] ggml : fix graph reallocation with multiple chunks (llama/16396) reallocation is needed if a single chunk grows in size, even if total allocation size stays the same or is lower --- ggml/src/ggml-alloc.c | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index fa46f3b4..929bc448 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) { free(alloc); } -static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) { - size_t max_size = 0; - for (int i = 0; i < alloc->n_chunks; i++) { - max_size += alloc->chunks[i]->max_size; - } - return max_size; +static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) { + return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0; } @@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) { free(buf); } -static int ggml_vbuffer_n_chunks(struct vbuffer * buf) { - int n = 0; - while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++; - return n; +static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) { + return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0; } static size_t ggml_vbuffer_size(struct vbuffer * buf) { @@ -885,12 +879,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } - size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; - size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); - // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views - if (new_size > cur_size || galloc->buffers[i] == NULL) { + bool realloc = galloc->buffers[i] == NULL; + size_t new_size = 0; + for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) { + size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0; + size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c); + new_size += new_chunk_size; + if (new_chunk_size > cur_chunk_size) { + realloc = true; + } + } + if (realloc) { #ifndef NDEBUG + size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif From 93c1305565b27b6c0d42aa8020dcb4e116a679b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 3 Oct 2025 19:18:56 +0300 Subject: [PATCH 026/104] metal : fix loop bound in ggml_mem_ranges (llama/16412) --- ggml/src/ggml-metal/ggml-metal-common.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index dc7d241c..95627d38 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t } bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { - for (int i = 0; i < GGML_MAX_DIMS; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i]) { ggml_mem_ranges_add_src(mrs, tensor->src[i]); } @@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * } bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { - for (int i = 0; i < GGML_MAX_DIMS; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i]) { if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { return false; From 49e0a426f356113c0878f892f2603725eeb78463 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 11 Oct 2025 17:59:36 +0300 Subject: [PATCH 027/104] vulkan : incremental shader builds (llama/16341) * vulkan (DRAFT): split shader generation by GLSL source file, to improve incremental build times * support dep-files so shaders are recompiled if their included files change * rename shader files which are used as "headers" to use .glsl extension * move glslc extension detection shaders to separate folders * the above is to prevent them from getting glob'd with the actual compute shaders that need to be compiled * vulkan : only write embedded shader .hpp/.cpp when they change * avoid recompiling ggml-vulkan.cpp when editing shaders * pass single --source argument instead of --input-dir & --filter to shader gen * check for source file match earlier * fix hang in vulkan-shaders-gen when there are compilation errors * early out did not decrement compile_count * clean up * fix glslc integer dot product test * unconditionally write the embedded shader cpp output * replace output filepath in generated dep-files to match output in CMakeLists --------- Co-authored-by: Jeff Bolz --- ggml/src/ggml-vulkan/CMakeLists.txt | 45 ++- ggml/src/ggml-vulkan/vulkan-shaders/acc.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/add.comp | 4 +- .../ggml-vulkan/vulkan-shaders/add_id.comp | 2 +- .../ggml-vulkan/vulkan-shaders/argmax.comp | 4 +- .../ggml-vulkan/vulkan-shaders/argsort.comp | 2 +- .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 4 +- .../ggml-vulkan/vulkan-shaders/concat.comp | 4 +- .../vulkan-shaders/contig_copy.comp | 4 +- .../ggml-vulkan/vulkan-shaders/conv2d_dw.comp | 2 +- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 2 +- .../vulkan-shaders/conv_transpose_1d.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/copy.comp | 4 +- .../vulkan-shaders/copy_from_quant.comp | 6 +- .../vulkan-shaders/copy_to_quant.comp | 8 +- ggml/src/ggml-vulkan/vulkan-shaders/cos.comp | 4 +- .../vulkan-shaders/count_equal.comp | 4 +- .../vulkan-shaders/dequant_f32.comp | 2 +- ...{dequant_funcs.comp => dequant_funcs.glsl} | 2 +- ..._funcs_cm2.comp => dequant_funcs_cm2.glsl} | 2 +- .../{dequant_head.comp => dequant_head.glsl} | 2 +- .../vulkan-shaders/dequant_iq1_m.comp | 2 +- .../vulkan-shaders/dequant_iq1_s.comp | 2 +- .../vulkan-shaders/dequant_iq2_s.comp | 2 +- .../vulkan-shaders/dequant_iq2_xs.comp | 2 +- .../vulkan-shaders/dequant_iq2_xxs.comp | 2 +- .../vulkan-shaders/dequant_iq3_s.comp | 2 +- .../vulkan-shaders/dequant_iq3_xxs.comp | 2 +- .../vulkan-shaders/dequant_iq4_nl.comp | 2 +- .../vulkan-shaders/dequant_iq4_xs.comp | 2 +- .../vulkan-shaders/dequant_mxfp4.comp | 2 +- .../vulkan-shaders/dequant_q2_k.comp | 2 +- .../vulkan-shaders/dequant_q3_k.comp | 2 +- .../vulkan-shaders/dequant_q4_0.comp | 2 +- .../vulkan-shaders/dequant_q4_1.comp | 2 +- .../vulkan-shaders/dequant_q4_k.comp | 2 +- .../vulkan-shaders/dequant_q5_0.comp | 2 +- .../vulkan-shaders/dequant_q5_1.comp | 2 +- .../vulkan-shaders/dequant_q5_k.comp | 2 +- .../vulkan-shaders/dequant_q6_k.comp | 2 +- .../vulkan-shaders/dequant_q8_0.comp | 2 +- .../vulkan-shaders/diag_mask_inf.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/div.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/exp.comp | 6 +- .../bfloat16.comp} | 0 .../coopmat.comp} | 0 .../coopmat2.comp} | 0 .../integer_dot.comp} | 0 .../vulkan-shaders/flash_attn.comp | 4 +- ...sh_attn_base.comp => flash_attn_base.glsl} | 0 .../vulkan-shaders/flash_attn_cm1.comp | 4 +- .../vulkan-shaders/flash_attn_cm2.comp | 6 +- .../src/ggml-vulkan/vulkan-shaders/geglu.comp | 4 +- .../ggml-vulkan/vulkan-shaders/geglu_erf.comp | 4 +- .../vulkan-shaders/geglu_quick.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp | 4 +- .../ggml-vulkan/vulkan-shaders/gelu_erf.comp | 4 +- .../vulkan-shaders/gelu_quick.comp | 4 +- ...ary_head.comp => generic_binary_head.glsl} | 4 +- .../{generic_head.comp => generic_head.glsl} | 0 ...nary_head.comp => generic_unary_head.glsl} | 0 .../ggml-vulkan/vulkan-shaders/get_rows.comp | 4 +- .../vulkan-shaders/get_rows_quant.comp | 6 +- .../{glu_head.comp => glu_head.glsl} | 2 +- .../{glu_main.comp => glu_main.glsl} | 0 .../vulkan-shaders/group_norm.comp | 4 +- .../vulkan-shaders/hardsigmoid.comp | 4 +- .../ggml-vulkan/vulkan-shaders/hardswish.comp | 4 +- .../ggml-vulkan/vulkan-shaders/im2col.comp | 5 +- .../ggml-vulkan/vulkan-shaders/im2col_3d.comp | 5 +- .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 4 +- .../vulkan-shaders/leaky_relu.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/mul.comp | 4 +- .../vulkan-shaders/mul_mat_vec.comp | 2 +- ...at_vec_base.comp => mul_mat_vec_base.glsl} | 4 +- .../vulkan-shaders/mul_mat_vec_iq1_m.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq1_s.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq2_s.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq2_xs.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq2_xxs.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq3_s.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq3_xxs.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 2 +- .../vulkan-shaders/mul_mat_vecq.comp | 4 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 4 +- .../vulkan-shaders/mul_mm_cm2.comp | 6 +- .../{mul_mm_funcs.comp => mul_mm_funcs.glsl} | 0 .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 +- ...{mul_mmq_funcs.comp => mul_mmq_funcs.glsl} | 2 +- .../ggml-vulkan/vulkan-shaders/multi_add.comp | 6 +- ggml/src/ggml-vulkan/vulkan-shaders/norm.comp | 4 +- .../vulkan-shaders/opt_step_adamw.comp | 4 +- .../vulkan-shaders/opt_step_sgd.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/pad.comp | 2 +- .../ggml-vulkan/vulkan-shaders/pool2d.comp | 2 +- .../vulkan-shaders/quantize_q8_1.comp | 2 +- .../src/ggml-vulkan/vulkan-shaders/reglu.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/relu.comp | 4 +- .../ggml-vulkan/vulkan-shaders/repeat.comp | 4 +- .../vulkan-shaders/repeat_back.comp | 4 +- .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 4 +- .../vulkan-shaders/rms_norm_back.comp | 4 +- .../vulkan-shaders/rms_norm_partials.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/roll.comp | 4 +- .../{rope_head.comp => rope_head.glsl} | 4 +- .../vulkan-shaders/rope_multi.comp | 2 +- .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 2 +- .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 2 +- .../vulkan-shaders/rope_vision.comp | 2 +- .../vulkan-shaders/{rte.comp => rte.glsl} | 0 .../src/ggml-vulkan/vulkan-shaders/scale.comp | 4 +- .../ggml-vulkan/vulkan-shaders/sigmoid.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/silu.comp | 4 +- .../ggml-vulkan/vulkan-shaders/silu_back.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/sin.comp | 4 +- .../ggml-vulkan/vulkan-shaders/soft_max.comp | 2 +- .../vulkan-shaders/soft_max_back.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp | 4 +- .../ggml-vulkan/vulkan-shaders/square.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/sub.comp | 4 +- .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 2 +- .../ggml-vulkan/vulkan-shaders/swiglu.comp | 4 +- .../vulkan-shaders/swiglu_oai.comp | 4 +- ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp | 4 +- .../vulkan-shaders/timestep_embedding.comp | 2 +- .../vulkan-shaders/{types.comp => types.glsl} | 0 .../ggml-vulkan/vulkan-shaders/upscale.comp | 2 +- .../vulkan-shaders/{utils.comp => utils.glsl} | 0 .../vulkan-shaders/vulkan-shaders-gen.cpp | 294 +++++++++++------- 133 files changed, 404 insertions(+), 315 deletions(-) rename ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp => dequant_funcs.glsl} (99%) rename ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp => dequant_funcs_cm2.glsl} (99%) rename ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp => dequant_head.glsl} (91%) rename ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp => feature-tests/bfloat16.comp} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp => feature-tests/coopmat.comp} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp => feature-tests/coopmat2.comp} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp => feature-tests/integer_dot.comp} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp => flash_attn_base.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp => generic_binary_head.glsl} (97%) rename ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp => generic_head.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp => generic_unary_head.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp => glu_head.glsl} (95%) rename ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp => glu_main.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp => mul_mat_vec_base.glsl} (99%) rename ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp => mul_mm_funcs.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{mul_mmq_funcs.comp => mul_mmq_funcs.glsl} (99%) rename ggml/src/ggml-vulkan/vulkan-shaders/{rope_head.comp => rope_head.glsl} (97%) rename ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp => rte.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{types.comp => types.glsl} (100%) rename ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp => utils.glsl} (100%) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index b97e7bf9..83a83887 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,5 +1,6 @@ cmake_minimum_required(VERSION 3.19) cmake_policy(SET CMP0114 NEW) +cmake_policy(SET CMP0116 NEW) find_package(Vulkan COMPONENTS glslc REQUIRED) @@ -54,25 +55,25 @@ if (Vulkan_FOUND) # Test all shader extensions test_shader_extension_support( "GL_KHR_cooperative_matrix" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp" "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_NV_cooperative_matrix2" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp" "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_EXT_integer_dot_product" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_EXT_bfloat16" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp" "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" ) @@ -160,7 +161,6 @@ if (Vulkan_FOUND) set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") - set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp") set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") @@ -176,24 +176,35 @@ if (Vulkan_FOUND) add_custom_command( OUTPUT ${_ggml_vk_header} - ${_ggml_vk_source} - COMMAND ${_ggml_vk_genshaders_cmd} - --glslc ${Vulkan_GLSLC_EXECUTABLE} - --input-dir ${_ggml_vk_input_dir} --output-dir ${_ggml_vk_output_dir} --target-hpp ${_ggml_vk_header} - --target-cpp ${_ggml_vk_source} - --no-clean - - DEPENDS ${_ggml_vk_shader_files} - ${_ggml_vk_shaders_gen_sources} + DEPENDS ${_ggml_vk_shaders_gen_sources} vulkan-shaders-gen - - COMMENT "Generate vulkan shaders" + COMMENT "Generate vulkan shaders header" ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header}) - target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + foreach (file_full ${_ggml_vk_shader_files}) + get_filename_component(file ${file_full} NAME) + set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp") + + add_custom_command( + OUTPUT ${_ggml_vk_target_cpp} + DEPFILE ${_ggml_vk_target_cpp}.d + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --source ${file_full} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_target_cpp} + DEPENDS ${file_full} + ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders for ${file}" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp}) + endforeach() else() message(WARNING "Vulkan not found") diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index d896f1ef..5084a70e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 00cf2dd6..3bcfe690 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -6,8 +6,8 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp index 3ae8f011..495249d5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -2,7 +2,7 @@ #extension GL_EXT_control_flow_attributes : require -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp index a1d4c240..7c128776 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index dc53a401..c81b8445 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp index 1e5cb8da..65343189 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp index 9ee2f1fa..e4046983 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index 6567a8c5..ca1a3ac2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" #extension GL_EXT_control_flow_attributes : require diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp index 938c74da..70a30148 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 44a64ddc..0367e80b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -11,7 +11,7 @@ # extension GL_KHR_shader_subgroup_shuffle : enable #endif -#include "types.comp" +#include "types.glsl" // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout(binding = 0) readonly buffer A { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp index b17b4e83..5217e18b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index f476a2e3..9f8bfd3c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 978d4300..06df5095 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -1,8 +1,8 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" -#include "dequant_funcs.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" +#include "dequant_funcs.glsl" #if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) // 16 invocations needed for init_iq_shmem diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index bc2e1f2d..b8c40eec 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,7 +1,7 @@ #version 450 -#include "rte.comp" -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -14,7 +14,7 @@ const uint BLOCK_SIZE = 32; layout (binding = 0) readonly buffer S {float data_s[];}; #if defined(SET_ROWS) -#include "generic_binary_head.comp" +#include "generic_binary_head.glsl" layout (binding = 1) readonly buffer C {B_TYPE data_i[];}; layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; @@ -25,7 +25,7 @@ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; #endif #else -#include "generic_unary_head.comp" +#include "generic_unary_head.glsl" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp index 0b8d02f5..db6865db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp index d9345497..e75df667 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -2,8 +2,8 @@ #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" -#include "generic_head.comp" +#include "types.glsl" +#include "generic_head.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp index a4d3fca5..765afffa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 73fef4fa..0d98f5a9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #endif -#include "types.comp" +#include "types.glsl" #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 706540fd..6a5bb457 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,5 +1,5 @@ -#include "types.comp" +#include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl similarity index 91% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl index 8d806435..addceafa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl @@ -10,4 +10,4 @@ layout (push_constant) uniform parameter uint nel; } p; -#include "types.comp" +#include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp index b604c188..637c95fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp index fd1e4e30..d1cbc5e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index 127c7b64..78490162 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp index a08331c4..9b8ce0a7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index 0ae9acd0..aacf07d0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index e4f42be9..f2c20b1d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index 19c7fdee..671c1f4a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index 46d9ad15..8f7833ea 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp index f930852a..a3136997 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp index ee496e9d..ffba5a77 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index d4e4e6ba..58dc2e5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index 3661f771..0c90be8b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp index 40818532..b92b2921 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp index 2f27eee6..6b63cbe5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 1370db36..8b7be557 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp index b20b8052..f1b0bac8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp index dc59fe3b..c495b31f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 3f3b839e..6bc04670 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 9cf34256..c8d6fcb4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp index bd1344a8..10844ddf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 26d8bc22..9cef8a8e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -10,7 +10,7 @@ layout (push_constant) uniform parameter uint n_past; } p; -#include "types.comp" +#include "types.glsl" layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp index 9fb69c6c..572472f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index a3941372..b69d4ddb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,8 +1,8 @@ #version 450 -#include "rte.comp" -#include "generic_head.comp" -#include "types.comp" +#include "rte.glsl" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e4247502..62acbf10 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -8,8 +8,8 @@ #extension GL_KHR_shader_subgroup_shuffle : enable -#include "types.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "flash_attn_base.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 0507df2d..2066a05b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -10,8 +10,8 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable -#include "types.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "flash_attn_base.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index a65553a4..910da1ab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -16,9 +16,9 @@ #extension GL_KHR_shader_subgroup_vote : enable #extension GL_EXT_null_initializer : enable -#include "types.comp" -#include "dequant_funcs_cm2.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "dequant_funcs_cm2.glsl" +#include "flash_attn_base.glsl" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp index f4268ed2..e017b503 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -10,4 +10,4 @@ float op(float a, float b) { return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp index cbd4cb36..759a1848 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation // ref: https://www.johndcook.com/blog/python_erf/ @@ -24,4 +24,4 @@ float op(float a, float b) { return 0.5f * a * (1.0f + erf_approx) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp index 3a2a6897..c4032ab2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" const float GELU_QUICK_COEF = -1.702f; @@ -8,4 +8,4 @@ float op(float a, float b) { return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp index 4cc7a68c..a95c2525 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp index 5fd5a5e7..58375aba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp index e6e6fcfd..bfdfe218 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl similarity index 97% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index 750e7857..99595fc6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,8 +1,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.comp" -#include "utils.comp" +#include "rte.glsl" +#include "utils.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index 7ef75cd7..76d83041 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 339f905f..9dba437e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" -#include "generic_binary_head.comp" -#include "dequant_funcs.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" +#include "dequant_funcs.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl similarity index 95% rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 51d70869..21689893 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.comp" +#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp index b6a0d564..bdf97dbb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp index 1da252cc..b4dbdf31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp index 3afc5882..1ec31591 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index f0f19a01..1827d647 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,9 +3,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.comp" - -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 9faa636a..4bf8b4ca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,9 +4,8 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.comp" - -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index deba8c39..83ef2f87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp index d90a99ae..b281e855 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp index 43de19df..02ef1eac 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index bb429dd5..9a03925c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index f761391e..450dee04 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -11,7 +11,7 @@ #define EXPERT_COUNT 8 #endif -#include "types.comp" +#include "types.glsl" #ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -32,7 +32,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 3) readonly buffer IDS {int data_ids[];}; #endif -#include "dequant_funcs.comp" +#include "dequant_funcs.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index e4acbd4f..4cb29238 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index 309da099..0b74b332 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index 8d01536f..e424af12 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp index c4960432..0cd906db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp index 94d4b92e..71bd72d1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index f021e404..a4b9ab1f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp index 3fe9dc3a..40849c69 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 423ceb8a..03ed25d3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index e91724a2..528f224d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index f9cde064..21d07d2e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 6c84ef3c..9e46c89a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index d53d9ee0..d7a7f642 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 8fb314fa..64293f6e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -6,13 +6,13 @@ #define MMQ #define B_TYPE block_q8_1_x4 -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #define K_PER_ITER 8 -#include "mul_mmq_funcs.comp" +#include "mul_mmq_funcs.glsl" uint a_offset, b_offset, d_offset; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 3cb24412..85400ac5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -28,7 +28,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif -#include "types.comp" +#include "types.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -195,7 +195,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif -#include "mul_mm_funcs.comp" +#include "mul_mm_funcs.glsl" void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 0e3065e0..2e04baa4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -18,8 +18,8 @@ #extension GL_EXT_bfloat16 : enable #endif -#include "types.comp" -#include "utils.comp" +#include "types.glsl" +#include "utils.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -71,7 +71,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#include "dequant_funcs_cm2.comp" +#include "dequant_funcs_cm2.glsl" #else #define DECODEFUNCA diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index f36add62..b5d761c0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -20,7 +20,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif -#include "types.comp" +#include "types.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -110,7 +110,7 @@ shared u16vec2 row_ids[4096]; shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif -#include "mul_mmq_funcs.comp" +#include "mul_mmq_funcs.glsl" void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index cdfb230f..fe71eb13 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -#include "types.comp" +#include "types.glsl" // Each iqs value maps to a 32-bit integer diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 854a2ad8..1e8f694a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,9 +8,9 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.comp" -#include "types.comp" -#include "utils.comp" +#include "rte.glsl" +#include "types.glsl" +#include "utils.glsl" layout (push_constant) uniform parameter2 { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp index 6627a50b..cc3ea0b7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp index e0214fe7..1f05f922 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp index 6426dede..1251f9cc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 0d81220c..f3c81768 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp index b6124411..d9d7166e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" #extension GL_EXT_shader_16bit_storage : require diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index 145c9fbd..0f3c6ca8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -17,7 +17,7 @@ layout (push_constant) uniform parameter uint ne; } p; -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const uint GROUP_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp index 0073d8f7..86be2669 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -1,9 +1,9 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { return max(a, 0.0f) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 4f806270..5725cef2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp index 1568b141..8f4b9a86 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp index d8627993..87df7829 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 41197e93..d5b211ff 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_binary_head.comp" -#include "types.comp" +#include "generic_binary_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp index 76009f3d..87707fc1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp index ba4677c2..4618b2c7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_binary_head.comp" -#include "types.comp" +#include "generic_binary_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #extension GL_KHR_shader_subgroup_arithmetic : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp index b9abe8de..68fbd0c7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl similarity index 97% rename from ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index 00e203e7..50fc1f1e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -1,8 +1,8 @@ -#include "types.comp" +#include "types.glsl" #extension GL_EXT_shader_16bit_storage : require -#include "rte.comp" +#include "rte.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 5808710c..111286b4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 366a7b1c..06e095be 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 9643bca9..6ba95754 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index cedacc4d..d37d1c10 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/rte.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index f10b0a02..35ec726a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" const uint num_threads = 128; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 5c9e5c35..32298d43 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp index 4d36f88e..7d1cc6f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp index f9afa9b1..e5d949ff 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp index d7c15a16..61f17b2f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 5f20a1ee..dca0d896 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -23,7 +23,7 @@ layout (push_constant) uniform parameter uint has_sinks; } p; -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 144ea58e..d873332e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -2,8 +2,8 @@ #extension GL_EXT_control_flow_attributes : enable -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp index 4bc697b9..70daad6c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp index ef43598b..4eb56afc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp index 72353cc3..bc924b52 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -2,8 +2,8 @@ #extension GL_EXT_shader_16bit_storage : require -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index 759204af..bc22aa7b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp index a28e7c6c..4fee433a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -1,9 +1,9 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { return a / (1.0f + exp(-a)) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp index 970750ee..bda9dea2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { float xi = min(a, p.limit); @@ -11,4 +11,4 @@ float op(float a, float b) { return out_glu; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 8a6f868f..7b5eb413 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp index ce8e0944..16055654 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter uint max_period; } p; -#include "types.comp" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 256 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/types.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/types.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 74771def..154a2172 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter float sf0; float sf1; float sf2; float sf3; } p; -#include "types.comp" +#include "types.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp b/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/utils.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 84bb9df9..e2726f1f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -34,13 +34,13 @@ std::mutex lock; std::vector> shader_fnames; +std::locale c_locale("C"); std::string GLSLC = "glslc"; -std::string input_dir = "vulkan-shaders"; +std::string input_filepath = ""; std::string output_dir = "/tmp"; -std::string target_hpp = "ggml-vulkan-shaders.hpp"; -std::string target_cpp = "ggml-vulkan-shaders.cpp"; -bool no_clean = false; +std::string target_hpp = ""; +std::string target_cpp = ""; const std::vector type_names = { "f32", @@ -75,6 +75,7 @@ enum MatMulIdType { }; namespace { + void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 HANDLE stdout_read, stdout_write; @@ -232,16 +233,87 @@ std::string basename(const std::string &path) { return path.substr(path.find_last_of("/\\") + 1); } +std::stringstream make_generic_stringstream() { + std::stringstream ss; + ss.imbue(c_locale); + return ss; +} + +std::string read_binary_file(const std::string& path, bool may_not_exist = false) { + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + if (!may_not_exist) { + std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n"; + } + return {}; + } + + fseek(f, 0, SEEK_END); + size_t size = ftell(f); + fseek(f, 0, SEEK_SET); + + std::string data(size, '\0'); + size_t read_size = fread(data.data(), 1, size, f); + fclose(f); + if (read_size != size) { + std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n"; + return {}; + } + + return data; +} + +void write_binary_file(const std::string& path, const std::string& content) { + FILE* f = fopen(path.c_str(), "wb"); + if (!f) { + std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n"; + return; + } + + size_t write_size = fwrite(content.data(), 1, content.size(), f); + fclose(f); + if (write_size != content.size()) { + std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n"; + return; + } +} + +void write_file_if_changed(const std::string& path, const std::string& content) { + std::string existing = read_binary_file(path, true); + if (existing != content) { + write_binary_file(path, content); + } +} + + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; +static bool generate_dep_file = true; -void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); - std::string out_fname = join_paths(output_dir, name + ".spv"); - std::string in_path = join_paths(input_dir, in_fname); +void decrement_compile_count(uint32_t * count) { + if (count) { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + compile_count_cond.notify_all(); + } +} +using compile_count_guard = std::unique_ptr; + +compile_count_guard acquire_compile_slot() { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = 16; + std::unique_lock guard(compile_count_mutex); + compile_count_cond.wait(guard, [N] { return compile_count < N; }); + compile_count++; + return compile_count_guard(&compile_count, &decrement_compile_count); +} + +void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map defines, bool coopmat, bool dep_file, compile_count_guard slot) { std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 @@ -249,11 +321,17 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O"; #ifdef _WIN32 - std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""}; #else - std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path}; #endif + if (dep_file) { + cmd.push_back("-MD"); + cmd.push_back("-MF"); + cmd.push_back("\"" + target_cpp + ".d\""); + } + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO cmd.push_back("-g"); #endif @@ -281,17 +359,23 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c return; } + if (dep_file) { + // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt + std::string dep = read_binary_file(target_cpp + ".d", true); + if (!dep.empty()) { + size_t pos = dep.find(out_path); + if (pos != std::string::npos) { + dep.replace(pos, out_path.length(), target_cpp); + } + write_binary_file(target_cpp + ".d", dep); + } + } + std::lock_guard guard(lock); - shader_fnames.push_back(std::make_pair(name, out_fname)); + shader_fnames.push_back(std::make_pair(name, out_path)); } catch (const std::exception& e) { std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; } - { - std::lock_guard guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; - } - compile_count_cond.notify_all(); } std::map merge_maps(const std::map& a, const std::map& b) { @@ -301,18 +385,24 @@ std::map merge_maps(const std::map> compiles; -void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - { - // wait until fewer than N compiles are in progress. - // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. - uint32_t N = 16; - std::unique_lock guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_path = join_paths(output_dir, name + ".spv"); + + if (input_filepath == "") { + // No input source to compile, only generate header for all shaders + shader_fnames.push_back(std::pair(name, out_path)); + return; + } else if (basename(input_filepath) != source) { + // Only compile shader variants matching the input filename + return; } - compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); + + compile_count_guard slot = acquire_compile_slot(); + compiles.push_back(std::async( + string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot))); + // Don't write the same dep file from multiple processes + generate_dep_file = false; } void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { @@ -485,7 +575,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; std::map base_dict = {{"FLOAT_TYPE", "float"}}; // matmul @@ -837,11 +926,11 @@ void process_shaders() { } void write_output_files() { - FILE* hdr = fopen(target_hpp.c_str(), "w"); - FILE* src = fopen(target_cpp.c_str(), "w"); + std::stringstream hdr = make_generic_stringstream(); + std::stringstream src = make_generic_stringstream(); - fprintf(hdr, "#include \n\n"); - fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + hdr << "#include \n\n"; + src << "#include \"" << basename(target_hpp) << "\"\n\n"; std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { @@ -853,91 +942,85 @@ void write_output_files() { const std::string& path = pair.second; #endif - FILE* spv = fopen(path.c_str(), "rb"); - if (!spv) { - std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } + hdr << "extern const uint64_t " << name << "_len;\n"; + hdr << "extern const unsigned char " << name << "_data[];\n\n"; - fseek(spv, 0, SEEK_END); - size_t size = ftell(spv); - fseek(spv, 0, SEEK_SET); + if (input_filepath != "") { + std::string data = read_binary_file(path); + if (data.empty()) { + continue; + } - std::vector data(size); - size_t read_size = fread(data.data(), 1, size, spv); - fclose(spv); - if (read_size != size) { - std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } - - fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); - fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); - - fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); - for (size_t i = 0; i < size; ++i) { - fprintf(src, "0x%02x,", data[i]); - if ((i + 1) % 12 == 0) fprintf(src, "\n"); - } - fprintf(src, "\n};\n\n"); - - if (!no_clean) { - std::remove(path.c_str()); + src << "const uint64_t " << name << "_len = " << data.size() << ";\n"; + src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex; + auto bytes = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.size(); ++i) { + src << "0x" << static_cast(bytes[i]) << ","; + if ((i + 1) % 12 == 0) src << "\n"; + } + src << std::dec << "\n};\n\n"; } } std::string suffixes[2] = {"_f32", "_f16"}; - for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) { - fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); - fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); - std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; - std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; + for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + + std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; + if (basename(input_filepath) != op_file) { + continue; + } + std::stringstream data = make_generic_stringstream(); + std::stringstream len = make_generic_stringstream(); + data << "const void * " << op << "_data[2][2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t t1 = 0; t1 < 2; ++t1) { if (t1 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t t2 = 0; t2 < 2; ++t2) { if (t2 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t rte = 0; rte < 2; ++rte) { if (rte == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } - data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); - len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); - data += "_data,"; - len += "_len,"; + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + data << "_data,"; + len << "_len,"; if (rte == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t2 == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t1 == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t0 == 1) { - data += "};\n"; - len += "};\n"; + data << "};\n"; + len << "};\n"; } } - fputs(data.c_str(), src); - fputs(len.c_str(), src); + src << data.str(); + src << len.str(); } std::vector btypes = {"f16", "f32"}; @@ -951,20 +1034,25 @@ void write_output_files() { if (btype == "q8_1" && !is_legacy_quant(tname)) { continue; } - fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str()); - fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str()); - std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n"; - std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n"; - fputs(data.c_str(), src); - fputs(len.c_str(), src); + hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } - fclose(hdr); - fclose(src); -} + if (input_filepath == "") { + write_file_if_changed(target_hpp, hdr.str()); + } + if (target_cpp != "") { + write_binary_file(target_cpp, src.str()); + } } +} // namespace + int main(int argc, char** argv) { std::map args; for (int i = 1; i < argc; ++i) { @@ -982,8 +1070,8 @@ int main(int argc, char** argv) { if (args.find("--glslc") != args.end()) { GLSLC = args["--glslc"]; // Path to glslc } - if (args.find("--input-dir") != args.end()) { - input_dir = args["--input-dir"]; // Directory containing shader sources + if (args.find("--source") != args.end()) { + input_filepath = args["--source"]; // The shader source file to compile } if (args.find("--output-dir") != args.end()) { output_dir = args["--output-dir"]; // Directory for containing SPIR-V output @@ -994,14 +1082,6 @@ int main(int argc, char** argv) { if (args.find("--target-cpp") != args.end()) { target_cpp = args["--target-cpp"]; // Path to generated cpp file } - if (args.find("--no-clean") != args.end()) { - no_clean = true; // Keep temporary SPIR-V files in output-dir after build - } - - if (!directory_exists(input_dir)) { - std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; - return EXIT_FAILURE; - } if (!directory_exists(output_dir)) { if (!create_directory(output_dir)) { From af51bbab88bb28429abba6d660c96aadec2e2da9 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Sat, 4 Oct 2025 12:49:16 +0300 Subject: [PATCH 028/104] rpc : add support for multiple devices (llama/16276) * rpc : add support for multiple devices Allow rpc-server to expose multiple devices from a single endpoint. Change RPC protocol to include device identifier where needed. closes: #15210 * fixes * use ggml_backend_reg_t * address review comments * fix llama-bench backend report * address review comments, change device naming * fix cmd order --- ggml/include/ggml-backend.h | 2 + ggml/include/ggml-rpc.h | 17 +- ggml/src/ggml-backend-impl.h | 3 - ggml/src/ggml-rpc/ggml-rpc.cpp | 401 +++++++++++++++++++++++---------- 4 files changed, 289 insertions(+), 134 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 62b6d65e..f1b74078 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -215,6 +215,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Backend (reg) enumeration diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1e674112..72eff002 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,26 +7,25 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 0 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API -GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device); GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device); -GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); - -GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); #ifdef __cplusplus } diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 07784d6f..6792ba98 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -209,9 +209,6 @@ extern "C" { void * context; }; - // Internal backend registry API - GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - // Add backend dynamic loading support to the backend // Initialize the backend diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index f99681c8..1a8739e7 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -105,9 +105,12 @@ enum rpc_cmd { RPC_CMD_INIT_TENSOR, RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_HELLO, + RPC_CMD_DEVICE_COUNT, RPC_CMD_COUNT, }; +static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); + // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; @@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp { uint8_t patch; }; +struct rpc_msg_device_count_rsp { + uint32_t device_count; +}; + struct rpc_msg_get_alloc_size_req { + uint32_t device; rpc_tensor tensor; }; @@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req { }; struct rpc_msg_alloc_buffer_req { + uint32_t device; uint64_t size; }; @@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp { uint64_t remote_size; }; +struct rpc_msg_get_alignment_req { + uint32_t device; +}; + struct rpc_msg_get_alignment_rsp { uint64_t alignment; }; +struct rpc_msg_get_max_size_req { + uint32_t device; +}; + struct rpc_msg_get_max_size_rsp { uint64_t max_size; }; @@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp { uint8_t result; }; +struct rpc_msg_get_device_memory_req { + uint32_t device; +}; + struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; @@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() { struct ggml_backend_rpc_buffer_type_context { std::string endpoint; + uint32_t device; std::string name; - size_t alignment; - size_t max_size; + size_t alignment; + size_t max_size; }; struct ggml_backend_rpc_context { std::string endpoint; + uint32_t device; std::string name; }; @@ -653,7 +676,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; - rpc_msg_alloc_buffer_req request = {size}; + rpc_msg_alloc_buffer_req request = {buft_ctx->device, size}; rpc_msg_alloc_buffer_rsp response; auto sock = get_socket(buft_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); @@ -669,9 +692,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back } } -static size_t get_alignment(const std::shared_ptr & sock) { +static size_t get_alignment(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_alignment_req request = {device}; rpc_msg_get_alignment_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); return response.alignment; } @@ -681,9 +705,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ return buft_ctx->alignment; } -static size_t get_max_size(const std::shared_ptr & sock) { +static size_t get_max_size(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_max_size_req request = {device}; rpc_msg_get_max_size_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); return response.max_size; } @@ -700,7 +725,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty auto sock = get_socket(buft_ctx->endpoint); rpc_msg_get_alloc_size_req request; - + request.device = buft_ctx->device; request.tensor = serialize_tensor(tensor); rpc_msg_get_alloc_size_rsp response; @@ -754,7 +779,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector & tensors, tensors.push_back(serialize_tensor(tensor)); } -static void serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { +static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector & output) { uint32_t n_nodes = cgraph->n_nodes; std::vector tensors; std::unordered_set visited; @@ -762,24 +787,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o add_tensor(cgraph->nodes[i], tensors, visited); } // serialization format: - // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | uint32_t n_tensors = tensors.size(); - int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); output.resize(output_size, 0); - memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + uint8_t * dest = output.data(); + memcpy(dest, &device, sizeof(device)); + dest += sizeof(device); + memcpy(dest, &n_nodes, sizeof(n_nodes)); + dest += sizeof(n_nodes); for (uint32_t i = 0; i < n_nodes; i++) { - memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); } - uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); - *out_ntensors = n_tensors; - rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + dest += n_nodes * sizeof(uint64_t); + memcpy(dest, &n_tensors, sizeof(n_tensors)); + dest += sizeof(n_tensors); + rpc_tensor * out_tensors = (rpc_tensor *)dest; memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); } static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; std::vector input; - serialize_graph(cgraph, input); + serialize_graph(rpc_ctx->device, cgraph, input); rpc_msg_graph_compute_rsp response; auto sock = get_socket(rpc_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); @@ -804,12 +834,13 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_optimize = */ NULL, }; -ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) { static std::mutex mutex; std::lock_guard lock(mutex); + std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; // NOTE: buffer types are allocated and never freed; this is by design static std::unordered_map buft_map; - auto it = buft_map.find(endpoint); + auto it = buft_map.find(buft_name); if (it != buft_map.end()) { return it->second; } @@ -818,34 +849,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); return nullptr; } - size_t alignment = get_alignment(sock); - size_t max_size = get_max_size(sock); + size_t alignment = get_alignment(sock, device); + size_t max_size = get_max_size(sock, device); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .device = */ device, + /* .name = */ buft_name, /* .alignment = */ alignment, /* .max_size = */ max_size }; - + auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { /* .iface = */ ggml_backend_rpc_buffer_type_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .device = */ ggml_backend_reg_dev_get(reg, device), /* .context = */ buft_ctx }; - buft_map[endpoint] = buft; + buft_map[buft_name] = buft; return buft; } -ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { +ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { + std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name }; - + auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .iface = */ ggml_backend_rpc_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .device = */ ggml_backend_reg_dev_get(reg, device), /* .context = */ ctx }; return backend; @@ -855,37 +889,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); } -static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { +static void get_device_memory(const std::shared_ptr & sock, uint32_t device, size_t * free, size_t * total) { + rpc_msg_get_device_memory_req request; + request.device = device; rpc_msg_get_device_memory_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); *free = response.free_mem; *total = response.total_mem; } -void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { +void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) { auto sock = get_socket(endpoint); if (sock == nullptr) { *free = 0; *total = 0; return; } - get_device_memory(sock, free, total); + get_device_memory(sock, device, free, total); } // RPC server-side implementation class rpc_server { public: - rpc_server(ggml_backend_t backend, const char * cache_dir) - : backend(backend), cache_dir(cache_dir) { + rpc_server(std::vector backends, const char * cache_dir) + : backends(std::move(backends)), cache_dir(cache_dir) { } ~rpc_server(); void hello(rpc_msg_hello_rsp & response); - void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); - void get_alignment(rpc_msg_get_alignment_rsp & response); - void get_max_size(rpc_msg_get_max_size_rsp & response); + bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); + bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response); + bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response); bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); bool free_buffer(const rpc_msg_free_buffer_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request); @@ -906,7 +942,7 @@ private: std::unordered_map & tensor_map); - ggml_backend_t backend; + std::vector backends; const char * cache_dir; std::unordered_set buffers; }; @@ -919,6 +955,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) { } bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } ggml_backend_buffer_type_t buft; struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), @@ -935,10 +975,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); return false; } - LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data); + LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data); if (tensor->buffer == nullptr) { //No buffer allocated. - buft = ggml_backend_get_default_buffer_type(backend); + buft = ggml_backend_get_default_buffer_type(backends[dev_id]); } else { buft = tensor->buffer->buft; } @@ -948,33 +988,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ return true; } -void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); response.remote_ptr = 0; response.remote_size = 0; if (buffer != nullptr) { response.remote_ptr = reinterpret_cast(buffer); response.remote_size = buffer->size; - LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", + __func__, dev_id, request.size, response.remote_ptr, response.remote_size); buffers.insert(buffer); } else { - LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size); } + return true; } -void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); size_t alignment = ggml_backend_buft_get_alignment(buft); - LOG_DBG("[%s] alignment: %lu\n", __func__, alignment); + LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment); response.alignment = alignment; + return true; } -void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); size_t max_size = ggml_backend_buft_get_max_size(buft); - LOG_DBG("[%s] max_size: %lu\n", __func__, max_size); + LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size); response.max_size = max_size; + return true; } bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { @@ -1332,23 +1388,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id, bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { // serialization format: - // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | - if (input.size() < sizeof(uint32_t)) { + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + if (input.size() < 2*sizeof(uint32_t)) { + return false; + } + const uint8_t * src = input.data(); + uint32_t device; + memcpy(&device, src, sizeof(device)); + src += sizeof(device); + if (device >= backends.size()) { return false; } uint32_t n_nodes; - memcpy(&n_nodes, input.data(), sizeof(n_nodes)); - if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { + memcpy(&n_nodes, src, sizeof(n_nodes)); + src += sizeof(n_nodes); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { return false; } - const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); + const uint64_t * nodes = (const uint64_t *)src; + src += n_nodes*sizeof(uint64_t); uint32_t n_tensors; - memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); - if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { + memcpy(&n_tensors, src, sizeof(n_tensors)); + src += sizeof(n_tensors); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { return false; } - const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); - LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); + const rpc_tensor * tensors = (const rpc_tensor *)src; + LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); @@ -1380,7 +1446,7 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph return false; } } - ggml_status status = ggml_backend_graph_compute(backend, graph); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); response.result = status; return true; } @@ -1391,9 +1457,9 @@ rpc_server::~rpc_server() { } } -static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, - sockfd_t sockfd, size_t free_mem, size_t total_mem) { - rpc_server server(backend, cache_dir); +static void rpc_serve_client(const std::vector & backends, const char * cache_dir, + sockfd_t sockfd, const std::vector & free_mem, const std::vector & total_mem) { + rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { return; @@ -1425,13 +1491,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, // HELLO command is handled above return; } + case RPC_CMD_DEVICE_COUNT: { + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_device_count_rsp response; + response.device_count = backends.size(); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; - server.alloc_buffer(request, response); + if (!server.alloc_buffer(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1452,22 +1531,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_GET_ALIGNMENT: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_alignment_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; - server.get_alignment(response); + if (!server.get_alignment(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_max_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; - server.get_max_size(response); + if (!server.get_max_size(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1593,12 +1678,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_GET_DEVICE_MEMORY: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_device_memory_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + auto dev_id = request.device; + if (dev_id >= backends.size()) { return; } rpc_msg_get_device_memory_rsp response; - response.free_mem = free_mem; - response.total_mem = total_mem; + response.free_mem = free_mem[dev_id]; + response.total_mem = total_mem[dev_id]; + LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id, + response.free_mem, response.total_mem); if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1612,16 +1704,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, } } -void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem) { +void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) { + if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) { + fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); + return; + } + std::vector backends; + std::vector free_mem_vec(free_mem, free_mem + n_devices); + std::vector total_mem_vec(total_mem, total_mem + n_devices); printf("Starting RPC server v%d.%d.%d\n", RPC_PROTO_MAJOR_VERSION, RPC_PROTO_MINOR_VERSION, RPC_PROTO_PATCH_VERSION); printf(" endpoint : %s\n", endpoint); printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); - printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + printf("Devices:\n"); + for (size_t i = 0; i < n_devices; i++) { + auto dev = devices[i]; + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024); + auto backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); + return; + } + backends.push_back(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, n_threads); + } + } + } std::string host; int port; @@ -1649,22 +1766,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint fprintf(stderr, "Failed to accept client connection\n"); return; } - printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem); + rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec); printf("Client connection closed\n"); fflush(stdout); } #ifdef _WIN32 WSACleanup(); #endif + for (auto backend : backends) { + ggml_backend_free(backend); + } } // device interface struct ggml_backend_rpc_device_context { std::string endpoint; + uint32_t device; std::string name; + std::string description; }; static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { @@ -1676,15 +1798,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ctx->name.c_str(); + return ctx->description.c_str(); } static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); - - GGML_UNUSED(dev); + ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total); } static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { @@ -1710,7 +1830,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ggml_backend_rpc_init(ctx->endpoint.c_str()); + return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device); GGML_UNUSED(params); } @@ -1718,7 +1838,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device); GGML_UNUSED(dev); } @@ -1736,7 +1856,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; - return buft_ctx->endpoint == dev_ctx->endpoint; + return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device; } static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { @@ -1759,28 +1879,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { // backend reg interface -static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { - return "RPC"; +struct ggml_backend_rpc_reg_context { + std::string name; + std::vector devices; +}; - GGML_UNUSED(reg); +static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->name.c_str() : "RPC"; } static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { - return 0; - - GGML_UNUSED(reg); + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->devices.size() : 0; } static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { - GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + if (ctx == nullptr) { + GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead"); + } else { + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; + } } static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { - return (void *)ggml_backend_rpc_add_device; + if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) { + return (void *)ggml_backend_rpc_add_server; } if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { return (void *)ggml_backend_rpc_start_server; @@ -1807,30 +1933,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) { return &ggml_backend_rpc_reg; } -ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) { - static std::unordered_map dev_map; - - static std::mutex mutex; - std::lock_guard lock(mutex); - - if (dev_map.find(endpoint) != dev_map.end()) { - return dev_map[endpoint]; - } - - ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", - }; - - ggml_backend_dev_t dev = new ggml_backend_device { - /* .iface = */ ggml_backend_rpc_device_i, - /* .reg = */ ggml_backend_rpc_reg(), - /* .context = */ ctx, - }; - - dev_map[endpoint] = dev; - - return dev; +static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) { + auto sock = get_socket(endpoint); + rpc_msg_device_count_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.device_count; } +static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { + static std::unordered_map reg_map; + static std::mutex mutex; + static uint32_t dev_id = 0; + std::lock_guard lock(mutex); + if (reg_map.find(endpoint) != reg_map.end()) { + return reg_map[endpoint]; + } + uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint); + if (dev_count == 0) { + return nullptr; + } + ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context; + ctx->name = "RPC[" + std::string(endpoint) + "]"; + for (uint32_t ind = 0; ind < dev_count; ind++) { + std::string dev_name = "RPC" + std::to_string(dev_id); + std::string dev_desc = std::string(endpoint); + ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc + }; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_rpc_device_i, + /* .reg = */ ggml_backend_rpc_reg(), + /* .context = */ dev_ctx, + }; + ctx->devices.push_back(dev); + dev_id++; + } + ggml_backend_reg_t reg = new ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_interface, + /* .context = */ ctx + }; + reg_map[endpoint] = reg; + return reg; +} + + GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) From 93882335a8a0b3435869d1d47b506dfccea9d044 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Sat, 4 Oct 2025 16:22:45 +0300 Subject: [PATCH 029/104] rpc : check src buffer when copying tensor (llama/16421) Only dst buffer is guaranteed to be an RPC buffer. Add check for the src one. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 37 ++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1a8739e7..aad48d62 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con RPC_STATUS_ASSERT(status); } +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - // check if src and dst are on the same server - ggml_backend_buffer_t src_buffer = src->buffer; - ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; - ggml_backend_buffer_t dst_buffer = dst->buffer; - ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; - if (src_ctx->sock != dst_ctx->sock) { - return false; + if (ggml_backend_buffer_is_rpc(src->buffer)) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; } - ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - rpc_msg_copy_tensor_req request; - request.src = serialize_tensor(src); - request.dst = serialize_tensor(dst); - rpc_msg_copy_tensor_rsp response; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return response.result; + return false; } static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { From 2ca8fa37fa6708f890c1b0e3d520cd4e8cfb645d Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 4 Oct 2025 20:04:27 +0000 Subject: [PATCH 030/104] vulkan: use a more appropriate amount of threads when generating shaders (llama/16418) * use a more flexible amount of threads * fix windows compile and 0 thread case * nominmax --- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e2726f1f..f0cc24ff 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -1,5 +1,3 @@ - - #include #include #include @@ -22,6 +20,7 @@ #include #ifdef _WIN32 + #define NOMINMAX #include #include // For _mkdir on Windows #else @@ -306,7 +305,7 @@ using compile_count_guard = std::unique_ptr guard(compile_count_mutex); compile_count_cond.wait(guard, [N] { return compile_count < N; }); compile_count++; From b8bdf061829942e7843e74e0988ed0e417ae10f1 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sat, 4 Oct 2025 20:59:31 -0700 Subject: [PATCH 031/104] ggml webgpu: actually add softmax, fix rms_norm offset (llama/16400) * implement soft_max * Fix soft_max data race * Temporary fix, wait on each submit --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 8 ++++++++ ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 2 +- ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index de68c568..e795ca3f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -424,6 +424,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx->staged_param_bufs.push_back(params_bufs); if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { ggml_backend_webgpu_submit_queue(ctx); + ggml_backend_webgpu_wait_on_submission(ctx); } } } @@ -1060,6 +1061,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_SCALE: ggml_webgpu_scale(ctx, src0, node); break; + case GGML_OP_SOFT_MAX: + ggml_webgpu_soft_max(ctx, src0, src1, src2, node); + break; default: return false; } @@ -1806,6 +1810,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SCALE: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_SOFT_MAX: + supports_op = op->type == GGML_TYPE_F32; + break; default: break; } @@ -1949,6 +1956,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_rope_pipeline(ctx); ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); + ggml_webgpu_init_soft_max_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index 4f72bb1c..712b921f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -84,7 +84,7 @@ fn main(@builtin(workgroup_id) wid: vec3, let i2 = i / params.ne1; let i1 = i % params.ne1; let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; - let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; let elems = (params.ne0 + wg_size - 1) / wg_size; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl index 64ab576c..c74dc4cc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl @@ -300,6 +300,7 @@ fn main(@builtin(workgroup_id) wid: vec3, workgroupBarrier(); } let row_max = scratch[0]; + workgroupBarrier(); var sum = 0.0f; col = lid.x; From 0f29d7c3fa1762e65ee96ee121b983c48ee26f0e Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 6 Oct 2025 14:17:12 +0200 Subject: [PATCH 032/104] ggml-cpu : fix leftover handling in ggml_vec_scale_f32 for SVE (llama/16443) This commit updates the leftover handling in ggml_vec_scale_f32. The motivation for this is that the code currently incorrectly assumes there would be fewer than ggml_f32_epr leftover elements. However, since the main loop processes 2*ggml_f32_epr elements per iteration , there can be up to (2*ggml_f32_epr - 1) leftover elements. The original single-pass leftover code could only process ggml_f32_epr elements, leaving some elements unscaled. Example scenario with 256-bit SVE: ``` ggml_f32_epr = 8 (elements per register) ggml_f32_step = 16 (two registers per iteration) n = 25 np = 16 leftovers = 9 elements (16-24) Original : processes only elements 16-23, misses element 24 This commit : loop processes elements 16-23, then element 24 ``` Refs: https://github.com/ggml-org/llama.cpp/actions/runs/18070620247/job/51419855630 --- ggml/src/ggml-cpu/vec.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 341e64e6..f95ca94e 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -654,11 +654,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } // leftovers // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b32(np, n); - ay1 = svld1_f32(pg, y + np); + for (int i = np; i < n; i += ggml_f32_epr) { + svbool_t pg = svwhilelt_b32(i, n); + ay1 = svld1_f32(pg, y + i); ay1 = svmul_f32_m(pg, ay1, vx); - svst1_f32(pg, y + np, ay1); + svst1_f32(pg, y + i, ay1); } #elif defined(__riscv_v_intrinsic) for (int i = 0, avl; i < n; i += avl) { From 0e431b3cea6b2eca455f1bc2816b36ea6c6a0c88 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Oct 2025 16:05:27 +0300 Subject: [PATCH 033/104] ggml : fix unaligned access in AMX code (llama/16315) --- ggml/src/ggml-cpu/amx/amx.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 867e158d..895a5713 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous is_contiguous_2d(op->src[1]) && // src1 must be contiguous op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && + op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { // src1 must be host buffer From 1a4116f9423602fea82099798b73ca2fd1f3885a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 08:21:40 +0300 Subject: [PATCH 034/104] metal : various optimizations + refactoring (llama/16446) * metal : ssm_scan minor opts * metal : get_rows optimize * metal : cpy optimize * metal : ssm_conv opt * metal : ssm_scan simplify * metal : ssm_Scan opt --- ggml/src/ggml-metal/ggml-metal-device.cpp | 22 +- ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal-impl.h | 18 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 78 +-- ggml/src/ggml-metal/ggml-metal.metal | 588 +++++++--------------- 5 files changed, 258 insertions(+), 452 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 819f31c8..d9e92044 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar char base[256]; char name[256]; - snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + const char * suffix = ""; + + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); @@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar } ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + char base[256]; char name[256]; - if (op->src[3]->ne[0] == 1) { - snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type)); - } else { - snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); - } - snprintf(name, 256, "%s", base); + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { @@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 523f9d71..95279730 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } + return true; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 88c98423..908e2e1c 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -178,6 +178,7 @@ typedef struct { } ggml_metal_kargs_clamp; typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -572,32 +573,45 @@ typedef struct { int64_t n_seq_tokens; int64_t n_seqs; uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e85a223c..7497d7c1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, @@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, + /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, }; + const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nw0 = (args.ne00t + nth - 1)/nth; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); return 1; } @@ -1117,7 +1126,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); @@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb00 =*/ nb00, /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, + /*.ns12 =*/ nb12/nb10, /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, /*.nb21 =*/ nb21, + /*.ns21 =*/ nb21/nb20, /*.nb22 =*/ nb22, + /*.ne30 =*/ ne30, /*.nb31 =*/ nb31, /*.nb41 =*/ nb41, /*.nb42 =*/ nb42, + /*.ns42 =*/ nb42/nb40, /*.nb43 =*/ nb43, /*.nb51 =*/ nb51, /*.nb52 =*/ nb52, + /*.ns52 =*/ nb52/nb50, /*.nb53 =*/ nb53, + /*.nb0 =*/ nb0, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const size_t sms = ggml_metal_pipeline_get_smem(pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); - if (ne30 == 1) { - // Mamba-2 - ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); - } else { - GGML_ASSERT(d_inner == 1); - ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1); - } + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); return 1; } @@ -1273,26 +1287,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); - // TODO: support - //const int32_t nk00 = ne00/ggml_blck_size(op->type); - const int32_t nk00 = ne00; - - int nth = 32; // SIMD width - - while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; + int64_t nk0 = ne00; + if (ggml_is_quantized(op->src[0]->type)) { + nk0 = ne00/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne00/ggml_blck_size(op->type); } - nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; // TODO: relax this constraint in the future if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { - if (nth > nk00) { - nrptg = (nth + nk00 - 1)/nk00; - nth = nk00; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nrptg--; @@ -1300,10 +1311,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { } } - nth = std::min(nth, nk00); + nth = std::min(nth, nk0); ggml_metal_kargs_cpy args = { - /*.ne00 =*/ nk00, + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, /*.ne03 =*/ ne03, @@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 96df6f0c..f454cead 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2044,219 +2075,88 @@ kernel void kernel_ssm_scan_f32( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. - threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; - } - - // Assign the final state to the output buffer - s_buff[i] = s; -} - -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_group_f32( - constant ggml_metal_kargs_ssm_scan & args, - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { - - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq - - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); - - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; - - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); - - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -5770,21 +5670,17 @@ kernel void kernel_flash_attn_ext_vec_reduce( } template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5795,190 +5691,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -5986,11 +5762,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -6002,10 +5779,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -7765,66 +7544,60 @@ kernel void kernel_mul_mv_mxfp4_f32( template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + break; } } @@ -8310,12 +8083,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; From 6cf0c21b094771237e9ba9da7853d6f7bfca90f9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 08:22:35 +0300 Subject: [PATCH 035/104] tests : add -INF blocks to the KQ mask in the FA tests (llama/16380) * tests : add -INF blocks to the KQ mask in the FA tests * cont : bump -INF block size to 64 Co-authored-by: Jeff Bolz * ggml : prevent division by zero in FA CPU op --------- Co-authored-by: Jeff Bolz --- ggml/src/ggml-cpu/ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6275c830..8e1a2de1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices From 4bce4fa5e93b129165402450489061a9412c33e8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 08:23:30 +0300 Subject: [PATCH 036/104] metal : add support for non-padded FA KV (llama/16148) * metal : pad K, V and Mask when needed * cont : simplify * cuda : add TODO about KV padding requirement * metal : add comments * metal : remove mask padding requirement --- ggml/src/ggml-cuda/fattn.cu | 6 + ggml/src/ggml-metal/ggml-metal-device.cpp | 60 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 8 + ggml/src/ggml-metal/ggml-metal-impl.h | 31 ++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 243 +++++++++++++++++----- ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 175 ++++++++++++++-- 8 files changed, 458 insertions(+), 71 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index d7736d36..0c8e7b3e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int cc = ggml_cuda_info().devices[device].cc; + // TODO: temporary until support is extended + // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 + if (K->ne[1] % FATTN_KQ_STRIDE != 0) { + return BEST_FATTN_KERNEL_NONE; + } + switch (K->ne[0]) { case 64: case 128: diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d9e92044..46cc5134 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -924,6 +924,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_pad"); + + snprintf(name, 256, "%s_mask=%d_ncpsg=%d", + base, + has_mask, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, @@ -931,6 +975,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -943,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + // do bounds checks for the mask? + const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0); + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", "flash_attn_ext", ggml_type_name(op->src[1]->type), dk, dv); - snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, + bc_mask, ns10, ns20, nsg); @@ -970,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); @@ -989,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -1008,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( dk, dv); - snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, ns10, ns20, nsg, nwg); @@ -1029,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index f6ebf90a..ef049507 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg); + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, @@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( @@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 908e2e1c..1524b3ab 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -69,11 +69,12 @@ #define N_SG_IQ4_XS 2 // function constants offsets -#define FC_FLASH_ATTN_EXT 100 -#define FC_FLASH_ATTN_EXT_VEC 200 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 -#define FC_MUL_MV 400 -#define FC_MUL_MM 500 +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT 200 +#define FC_FLASH_ATTN_EXT_VEC 300 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400 +#define FC_MUL_MV 500 +#define FC_MUL_MM 600 // kernel argument structs // @@ -244,6 +245,24 @@ typedef struct { int32_t sect_3; } ggml_metal_kargs_rope; +typedef struct { + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + typedef struct { int32_t ne01; int32_t ne02; @@ -262,6 +281,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -296,6 +316,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7497d7c1..125cc64d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); + GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); GGML_TENSOR_LOCALS( int64_t, ne, node, ne); GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); @@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ggml_is_contiguous(node->src[1]), node->src[1]->name); } + if (node->src[2]) { + GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, + ggml_is_contiguous(node->src[2]), node->src[2]->name); + } + if (node->src[3]) { + GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, + ggml_is_contiguous(node->src[3]), node->src[3]->name); + } if (node) { GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, node->name); @@ -1889,20 +1901,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { return (ne01 < 20) && (ne00 % 32 == 0); } +size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const bool has_kvpad = ne11 % 32 != 0; + + if (has_kvpad) { + res += 32*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } else { + const bool has_kvpad = ne11 % 64 != 0; + + if (has_kvpad) { + res += 64*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } + + return res; +} + size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); - const int64_t nwg = 32; + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); - const int64_t ne01 = op->src[0]->ne[1]; - const int64_t ne02 = op->src[0]->ne[2]; - const int64_t ne03 = op->src[0]->ne[3]; - const int64_t ne20 = op->src[2]->ne[0]; + size_t res = 0; - // temp buffer for writing the results from each workgroup - // - ne20: the size of the Value head - // - + 2: the S and M values for each intermediate result - return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const int64_t nwg = 32; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + } + + return res; } int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { @@ -1924,8 +1985,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, nb, op, nb); - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == op->src[2]->type); @@ -1935,8 +1995,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne12 == ne22); GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); - GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && + "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); float scale; float max_bias; @@ -1963,6 +2023,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne01 < 65536); + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_pad = bid_dst; + bid_pad.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_tmp = bid_pad; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! @@ -1972,6 +2046,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; // 2*(2*ncpsg) @@ -2021,6 +2137,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2037,24 +2154,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2070,6 +2180,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2134,6 +2286,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2150,25 +2303,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); const size_t smem = FATTN_SMEM(nsg); @@ -2176,23 +2321,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + // using 1 workgroup -> write the result directly into dst - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); } else { // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); - ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - // write the results from each workgroup into a temp buffer - ggml_metal_buffer_id bid_tmp = bid_dst; - bid_tmp.offs += ggml_nbytes(op); - ggml_metal_encoder_set_buffer(enc, bid_tmp, 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 8df4c72e..6a6d8a79 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); // return true if we should use the FA vector kernel for this op bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index e11555a7..e53f37b2 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ } break; case GGML_OP_FLASH_ATTN_EXT: { - if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) { - res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); - } + res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); } break; default: break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f454cead..c52c6b48 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4349,10 +4349,83 @@ kernel void kernel_leaky_relu_f32_4( dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); } +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; @@ -4399,6 +4472,7 @@ void kernel_flash_attn_ext_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -4523,13 +4597,58 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = 0; ic < args.ne11; ic += C) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C) { + int ic = ic0; + + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } + } + + ic = 0; + } + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; - sm2[j*SH + tiisg] = pm2[jj][tiisg]; + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + pm2[jj] += NW; } @@ -4557,7 +4676,7 @@ void kernel_flash_attn_ext_impl( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); threadgroup const q_t * pq = sq; threadgroup s_t * ps = ss; @@ -4629,7 +4748,7 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -4751,7 +4870,7 @@ void kernel_flash_attn_ext_impl( { auto sst = ss; - device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; @@ -4793,7 +4912,7 @@ void kernel_flash_attn_ext_impl( simdgroup_load(vs, ss + 8*cc, SH, 0, false); for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4937,13 +5056,14 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5063,6 +5183,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_ constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; @@ -5100,6 +5221,7 @@ void kernel_flash_attn_ext_vec_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -5206,11 +5328,37 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - const int ic = ic0 + C*sgitg; + int ic = ic0 + C*sgitg; if (ic >= args.ne11) { break; } + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -5222,7 +5370,7 @@ void kernel_flash_attn_ext_vec_impl( // Q*K^T { - device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; @@ -5237,7 +5385,7 @@ void kernel_flash_attn_ext_vec_impl( mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } } else { - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; @@ -5335,7 +5483,7 @@ void kernel_flash_attn_ext_vec_impl( } if (is_same::value) { - device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -5348,7 +5496,7 @@ void kernel_flash_attn_ext_vec_impl( } } else { FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { const short i = ii*NL + tx; @@ -5520,13 +5668,14 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_vec_nsg) { // note: disabled cases to reduce library load time case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; From 4eea3efc4906edaffeba71b3ce10231323324d82 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 7 Oct 2025 13:48:56 -0700 Subject: [PATCH 037/104] ggml webgpu: profiling, CI updates, reworking of command submission (llama/16452) * Add profiling * More detailed profiling * Rework command submission to avoid global locks * Update wait handling * try new method of waiting on futures * Add serializing of command submission in some cases * Add new pool for timestamp queries and clean up logging * Serialize command submission in CI and leave a TODO note * Update webgpu CI * Add myself as WebGPU codeowner * Deadlock avoidance * Leave WebGPU/Vulkan CI serialized * Fix divide by 0 * Fix logic in division by inflight_threads * Update CODEOWNERS and remove serialize submit option --- ggml/CMakeLists.txt | 3 + ggml/src/ggml-webgpu/CMakeLists.txt | 8 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 720 ++++++++++++------ .../wgsl-shaders/mul_mat.tmpl.wgsl | 2 +- 4 files changed, 491 insertions(+), 242 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6ce52ffc..73032be6 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -222,6 +222,9 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_WEBGPU "ggml: use WebGPU" OFF) option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) +option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) +option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) + option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 78a985a4..c6a95d51 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -50,5 +50,13 @@ if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) endif() +if (GGML_WEBGPU_CPU_PROFILE) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_CPU_PROFILE=1) +endif() + +if (GGML_WEBGPU_GPU_PROFILE) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_GPU_PROFILE=1) +endif() + target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR}) target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET}) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e795ca3f..05e16cd4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -11,10 +11,12 @@ #include +#include #include #include #include #include +#include #include #include @@ -25,12 +27,44 @@ # define WEBGPU_LOG_DEBUG(msg) ((void) 0) #endif // GGML_WEBGPU_DEBUG +#ifdef GGML_WEBGPU_CPU_PROFILE +// total timing (aggregated) +# define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now(); + +# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \ + auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \ + double cpu_total_time_##id = \ + std::chrono::duration(cpu_total_end_##id - cpu_total_start_##id).count(); \ + (ctx)->cpu_time_ms[#id] += cpu_total_time_##id; + +// fine-grained timing (not included in totals) +# define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now(); + +# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \ + auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \ + double cpu_detail_time_##id = \ + std::chrono::duration(cpu_detail_end_##id - cpu_detail_start_##id).count(); \ + (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id; +#else +# define WEBGPU_CPU_PROFILE_TOTAL_START(id) +# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) +# define WEBGPU_CPU_PROFILE_DETAIL_START(id) +# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) +#endif // GGML_WEBGPU_CPU_PROFILE + +#ifdef GGML_WEBGPU_GPU_PROFILE +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +#endif + /* Constants */ -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 -#define WEBGPU_WAIT_ANY_BATCH_SIZE 64 -#define WEBGPU_MUL_MAT_WG_SIZE 64 -#define WEBGPU_NUM_PARAM_BUFS 100 +#define WEBGPU_MUL_MAT_WG_SIZE 256 +#define WEBGPU_NUM_PARAM_BUFS 32u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 +// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 @@ -66,6 +100,11 @@ struct webgpu_pool_bufs { wgpu::Buffer dev_buf; }; +// The futures to wait on for a single queue submission +struct webgpu_submission_futures { + std::vector futures; +}; + // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { std::vector free; @@ -112,6 +151,83 @@ struct webgpu_buf_pool { } }; +#ifdef GGML_WEBGPU_GPU_PROFILE +struct webgpu_gpu_profile_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + wgpu::QuerySet query_set; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_gpu_profile_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); + // Create a query set for 2 timestamps + wgpu::QuerySetDescriptor ts_query_set_desc = {}; + + ts_query_set_desc.type = wgpu::QueryType::Timestamp; + ts_query_set_desc.count = 2; + wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); + + free.push_back({ host_buf, dev_buf, ts_query_set }); + } + } + + webgpu_gpu_profile_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_gpu_profile_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + bufs.query_set.Destroy(); + } + free.clear(); + } +}; +#endif + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; +}; + +struct webgpu_command { + wgpu::CommandBuffer commands; + webgpu_pool_bufs params_bufs; + std::optional set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs timestamp_query_bufs; + std::string pipeline_name; +#endif +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; @@ -125,45 +241,50 @@ struct webgpu_context_struct { uint32_t max_wg_size_x; std::recursive_mutex mutex; + std::atomic_uint inflight_threads = 0; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - wgpu::ComputePipeline memset_pipeline; - wgpu::ComputePipeline mul_mat_pipeline[30][2]; - wgpu::ComputePipeline set_rows_pipeline; - wgpu::ComputePipeline get_rows_pipeline[30]; - wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type - wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace - wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace - wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split - wgpu::ComputePipeline scale_pipeline[2]; // inplace - wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace + webgpu_pipeline memset_pipeline; + webgpu_pipeline mul_mat_pipeline[30][2]; + webgpu_pipeline set_rows_pipeline; + webgpu_pipeline get_rows_pipeline[30]; + webgpu_pipeline get_rows_f32_no_vec_pipeline; + webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type + webgpu_pipeline add_pipeline[2][2]; // type, inplace + webgpu_pipeline sub_pipeline[2][2]; // type, inplace + webgpu_pipeline mul_pipeline[2][2]; // type, inplace + webgpu_pipeline div_pipeline[2][2]; // type, inplace + webgpu_pipeline rms_norm_pipeline[2]; // inplace + webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace + webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split + webgpu_pipeline scale_pipeline[2]; // inplace + webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace size_t memset_bytes_per_thread; // Staging buffer for reading data from the GPU wgpu::Buffer get_tensor_staging_buf; - // Command buffers which need to be submitted - std::vector staged_command_bufs; - - // Parameter buffers associated with the staged command buffers - std::vector staged_param_bufs; - // Buffers associated with set_rows operations, used to store potential errors - std::vector staged_set_row_error_bufs; - - std::vector callback_futures; - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif + +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif }; typedef std::shared_ptr webgpu_context; @@ -199,12 +320,10 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ static void ggml_webgpu_create_pipeline(wgpu::Device & device, - wgpu::ComputePipeline & pipeline, + webgpu_pipeline & pipeline, const char * shader_code, const char * label, const std::vector & constants = {}) { - WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); - wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -222,7 +341,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } - pipeline = device.CreateComputePipeline(&pipeline_desc); + pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } static void ggml_webgpu_create_buffer(wgpu::Device & device, @@ -230,8 +349,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, size_t size, wgpu::BufferUsage usage, const char * label) { - WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); - wgpu::BufferDescriptor buffer_desc; buffer_desc.size = size; buffer_desc.usage = usage; @@ -247,83 +364,35 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { - std::lock_guard lock(ctx->mutex); - if (ctx->callback_futures.empty()) { - // no existing callbacks, wait on queue submission - ctx->instance.WaitAny( - ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); - } else { - // WebGPU implementations may limit the number of futures that can be waited on at once, - // so wait in batches (64 is what Dawn supports). - for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) { - size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size()); - ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX); +static void ggml_backend_webgpu_wait(webgpu_context & ctx, + std::vector & futures, + bool block = true) { + // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, + // inflight_max may be 0, meaning that we must wait on all futures. + uint64_t timeout_ms = block ? UINT64_MAX : 0; + uint inflight_threads = ctx->inflight_threads; + uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + while (futures.size() >= inflight_max && futures.size() > 0) { + ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); + futures.erase(futures.begin()); + } + size_t i = 0; + while (i < futures.size()) { + auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); + switch (waitStatus) { + case wgpu::WaitStatus::Success: + futures.erase(futures.begin() + i); + break; + case wgpu::WaitStatus::TimedOut: + i++; + break; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + break; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + break; } - ctx->callback_futures.clear(); - } -} - -static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { - std::lock_guard lock(ctx->mutex); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()"); - if (ctx->staged_command_bufs.empty()) { - // Nothing to submit - return; - } - ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); - - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. - if (ctx->staged_set_row_error_bufs.size() > 0) { - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (auto & error_bufs : ctx->staged_set_row_error_bufs) { - // Copy the error buffer to the host buffer - encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize()); - } - wgpu::CommandBuffer commands = encoder.Finish(); - ctx->queue.Submit(1, &commands); - } - - ctx->staged_command_bufs.clear(); - std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); - std::vector staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs); - - // Free the staged parameter buffers once the submission completes - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - ctx->param_buf_pool.free_bufs(staged_param_bufs); - }); - ctx->callback_futures.push_back({ p_f }); - - // Check for errrors in SET_ROWS operations - for (auto & error_bufs : staged_set_row_error_bufs) { - wgpu::Future f = error_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->set_rows_error_buf_pool.free_bufs({ error_bufs }); - } - }); - ctx->callback_futures.push_back({ f }); } } @@ -347,7 +416,6 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. static void ggml_backend_webgpu_debug(webgpu_context & ctx) { - ggml_backend_webgpu_submit_queue(ctx); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -364,13 +432,85 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { } #endif -static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx, - wgpu::ComputePipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - const char * bind_group_label = nullptr, - bool submit_and_wait = false) { +static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector commands) { + std::vector command_buffers; + std::vector params_bufs; + std::vector set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector> pipeline_name_and_ts_bufs; +#endif + + for (const auto & command : commands) { + command_buffers.push_back(command.commands); + params_bufs.push_back(command.params_bufs); + if (command.set_rows_error_bufs) { + set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); + } + } + ctx->queue.Submit(command_buffers.size(), command_buffers.data()); + + std::vector futures; + + wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); + } + // Free the staged buffers + ctx->param_buf_pool.free_bufs({ params_bufs }); + }); + futures.push_back({ p_f }); + + for (const auto & bufs : set_rows_error_bufs) { + wgpu::Future f = bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); + } else { + const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->set_rows_error_buf_pool.free_bufs({ bufs }); + } + }); + futures.push_back({ f }); + } + +#ifdef GGML_WEBGPU_GPU_PROFILE + for (const auto & command : commands) { + auto label = command.pipeline_name; + auto ts_bufs = command.timestamp_query_bufs; + + wgpu::Future f = ts_bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); + } else { + const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); + // WebGPU timestamps are in ns; convert to ms + double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; + ctx->shader_gpu_time_ms[label] += elapsed_ms; + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); + } + }); + futures.push_back({ f }); + } +#endif + return { futures }; +} + +static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -388,45 +528,58 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & .size = params_bufs.dev_buf.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = bind_group_entries.size(); bind_group_desc.entries = bind_group_entries.data(); - if (bind_group_label) { - bind_group_desc.label = bind_group_label; - } + bind_group_desc.label = pipeline.name.c_str(); wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // --- Profiling: GPU timestamp queries --- + // Allocate a timestamp query buffer (2 timestamps: start/end) + webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); + } + + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + .endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); +#else wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipeline); +#endif + pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); pass.DispatchWorkgroups(wg_x, 1, 1); pass.End(); - wgpu::CommandBuffer commands = encoder.Finish(); - if (submit_and_wait) { - // Submit and wait immediately - ctx->queue.Submit(1, &commands); - ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); - } - ctx->param_buf_pool.free_bufs({ params_bufs }); - }), - UINT64_MAX); - } else { - // Lock the context mutex when pushing to the staging vectors. - std::lock_guard lock(ctx->mutex); - // Enqueue commands and only submit if we have enough staged commands - ctx->staged_command_bufs.push_back(commands); - ctx->staged_param_bufs.push_back(params_bufs); - if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - ggml_backend_webgpu_submit_queue(ctx); - ggml_backend_webgpu_wait_on_submission(ctx); - } + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Resolve the query set into the device buffer + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); +#endif + + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + if (set_rows_error_bufs) { + encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, + set_rows_error_bufs->host_buf.GetSize()); } + + wgpu::CommandBuffer commands = encoder.Finish(); + webgpu_command result = {}; + result.commands = commands; + result.params_bufs = params_bufs; + result.set_rows_error_bufs = set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + result.timestamp_query_bufs = ts_bufs; + result.pipeline_name = pipeline.name; +#endif + return result; } static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, @@ -440,7 +593,10 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, }; size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread; uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true); + + webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x); + std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; + ggml_backend_webgpu_wait(ctx, futures); } /** End WebGPU Actions */ @@ -456,8 +612,48 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); - // TODO: cleanup +#ifdef GGML_WEBGPU_CPU_PROFILE + std::cout << "\n[ggml_webgpu cpu profiling summary]\n"; + double total_cpu = 0.0; + for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + total_cpu += kv.second; + } + std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n"; + std::cout << "ggml_webgpu: cpu breakdown:\n"; + for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } + if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) { + std::cout << "ggml_webgpu: cpu detailed breakdown:\n"; + } + for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) { + double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; + double total_gpu = 0.0; + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + total_gpu += kv.second; + } + std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; + std::cout << "\nggml_webgpu: gpu breakdown:\n"; + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } +#endif + +#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE) + std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n"; +#endif + +#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) GGML_UNUSED(ctx); +#endif } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -490,7 +686,7 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } -static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -519,14 +715,16 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); } -static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { - return; + return std::nullopt; } webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); @@ -569,13 +767,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; - std::lock_guard lock(ctx->mutex); - ctx->staged_set_row_error_bufs.push_back(error_bufs); - - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); } -static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -610,14 +808,17 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size; - wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type]; + webgpu_pipeline pipeline = ctx->get_rows_pipeline[src->type]; if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) { pipeline = ctx->get_rows_f32_no_vec_pipeline; } - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), @@ -654,16 +855,15 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } -static void ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - wgpu::ComputePipeline & pipeline, - bool inplace) { +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + webgpu_pipeline & pipeline, + bool inplace) { std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -701,10 +901,10 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -736,15 +936,14 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src), - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src)); } -static void ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -822,13 +1021,13 @@ static void ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + webgpu_pipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { const int split = (src1 != nullptr); std::vector params = { @@ -875,13 +1074,13 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + webgpu_pipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -916,15 +1115,14 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x); } -static void ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here const int has_sink = (src2 != nullptr); @@ -989,14 +1187,14 @@ static void ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst), ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + ggml_nrows(dst)); } -// Returns true if node has enqueued work into the queue, false otherwise -static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +// Returns the encoded command, or std::nullopt if the operation is a no-op +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { - return false; + return std::nullopt; } WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); @@ -1011,63 +1209,49 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_RESHAPE: - return false; + return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - ggml_webgpu_cpy(ctx, src0, node); - break; + return ggml_webgpu_cpy(ctx, src0, node); case GGML_OP_SET_ROWS: - ggml_webgpu_set_rows(ctx, src0, src1, node); - break; + return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: - ggml_webgpu_get_rows(ctx, src0, src1, node); - break; + return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: - ggml_webgpu_mul_mat(ctx, src0, src1, node); - break; + return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); } case GGML_OP_SUB: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); } case GGML_OP_MUL: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); } case GGML_OP_DIV: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); } case GGML_OP_RMS_NORM: - ggml_webgpu_rms_norm(ctx, src0, node); - break; + return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: - ggml_webgpu_rope(ctx, src0, src1, src2, node); - break; + return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: - ggml_webgpu_glu(ctx, src0, src1, node); - break; + return ggml_webgpu_glu(ctx, src0, src1, node); case GGML_OP_SCALE: - ggml_webgpu_scale(ctx, src0, node); - break; + return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: - ggml_webgpu_soft_max(ctx, src0, src1, src2, node); - break; + return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); default: - return false; + return std::nullopt; } - return true; } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { @@ -1076,13 +1260,35 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); webgpu_context ctx = backend_ctx->webgpu_ctx; + WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); + + ctx->inflight_threads++; + + std::vector commands; + std::vector futures; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_webgpu_encode_node(ctx, cgraph->nodes[i]); + if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + commands.push_back(*cmd); + } + // compute the batch size based on the number of inflight threads + uint inflight_threads = ctx->inflight_threads; + uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + if (commands.size() >= batch_size) { + futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); + // Process events and check for completed submissions + ctx->instance.ProcessEvents(); + ggml_backend_webgpu_wait(ctx, futures, false); + commands.clear(); + } } - - ggml_backend_webgpu_submit_queue(ctx); - ggml_backend_webgpu_wait_on_submission(ctx); - + if (!commands.empty()) { + webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); + futures.push_back(new_futures); + } + ggml_backend_webgpu_wait(ctx, futures); + ctx->inflight_threads--; + WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); return GGML_STATUS_SUCCESS; } @@ -1108,7 +1314,6 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* GGML Backend Buffer Interface */ static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()"); ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); ctx->buffer.Destroy(); } @@ -1129,6 +1334,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe return; } + WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); @@ -1139,6 +1346,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1148,6 +1356,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; @@ -1170,8 +1379,17 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, remaining_size); } else { // wait for WriteBuffer to complete - ggml_backend_webgpu_wait_on_submission(webgpu_ctx); + webgpu_ctx->instance.WaitAny( + webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } + WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, @@ -1181,7 +1399,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - + WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; wgpu::Device device = webgpu_ctx->device; @@ -1221,12 +1439,15 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); webgpu_ctx->get_tensor_staging_buf.Unmap(); + WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); + WEBGPU_CPU_PROFILE_TOTAL_START(clear); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -1876,6 +2097,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); + WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device); + ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); webgpu_context ctx = reg_ctx->webgpu_ctx; @@ -1902,7 +2125,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; - wgpu::DeviceDescriptor dev_desc; +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); @@ -1916,8 +2143,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); }); ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, @@ -1939,6 +2166,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Initialize buffer pool for timestamp queries (profiling) + ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); +#endif + ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); @@ -1983,6 +2219,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .reg = */ reg, /* .context = */ &device_ctx, }; + + WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx); return &device; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 25e2185d..141db9b3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -870,7 +870,7 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; -@compute @workgroup_size(64) +@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) global_id: vec3) { let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; if (global_id.x >= total) { From 7ef78a72e11289203029420cd26089d2c903538d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 8 Oct 2025 10:57:53 +0300 Subject: [PATCH 038/104] metal : mark FA blocks (llama/16372) * metal : better unroll in the FA kernels * metal : index FA blocks * tests : restore [no ci] * metal : prevent division by zero in FA kernels * metal : fix -INF detection logic --- ggml/src/ggml-metal/ggml-metal-device.cpp | 48 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 6 + ggml/src/ggml-metal/ggml-metal-impl.h | 29 +++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 113 +++++++++++-- ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 191 +++++++++++++++++----- 7 files changed, 324 insertions(+), 65 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 46cc5134..e23abdda 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -959,7 +959,53 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); - ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_blk"); + + snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d", + base, + nqptg, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index ef049507..1034e4bb 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -141,6 +141,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( bool has_mask, int32_t ncpsg); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg); + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 1524b3ab..c9dff873 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -70,11 +70,19 @@ // function constants offsets #define FC_FLASH_ATTN_EXT_PAD 100 -#define FC_FLASH_ATTN_EXT 200 -#define FC_FLASH_ATTN_EXT_VEC 300 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400 -#define FC_MUL_MV 500 -#define FC_MUL_MM 600 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs // @@ -263,6 +271,17 @@ typedef struct { uint64_t nb33; } ggml_metal_kargs_flash_attn_ext_pad; +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + typedef struct { int32_t ne01; int32_t ne02; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 125cc64d..1137e210 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1918,19 +1918,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { const bool has_mask = op->src[3] != nullptr; if (ggml_metal_op_flash_attn_ext_use_vec(op)) { - const bool has_kvpad = ne11 % 32 != 0; + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; if (has_kvpad) { - res += 32*( + res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( nb11*ne12*ne13 + nb21*ne22*ne23 + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); } } else { - const bool has_kvpad = ne11 % 64 != 0; + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; if (has_kvpad) { - res += 64*( + res += OP_FLASH_ATTN_EXT_NCPSG*( nb11*ne12*ne13 + nb21*ne22*ne23 + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); @@ -1940,6 +1940,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { return res; } +size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (!has_mask) { + return res; + } + + const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); + + // this optimization is not useful for the vector kernels + if (is_vec) { + return res; + } + + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; + + const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; + const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; + + res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); + + return res; +} + size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -2034,18 +2072,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_buffer_id bid_pad = bid_dst; bid_pad.offs += ggml_nbytes(op); - ggml_metal_buffer_id bid_tmp = bid_pad; - bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + ggml_metal_buffer_id bid_blk = bid_pad; + bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + + ggml_metal_buffer_id bid_tmp = bid_blk; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! + const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + const bool has_kvpad = ne11 % ncpsg != 0; if (has_kvpad) { @@ -2083,11 +2126,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); - ggml_metal_op_concurrency_reset(ctx); + need_sync = true; } else { assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); } + if (has_mask) { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); + + ggml_metal_kargs_flash_attn_ext_blk args0 = { + /*.ne01 =*/ ne01, + /*.ne30 =*/ ne30, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src3, 1); + ggml_metal_encoder_set_buffer (enc, bid_blk, 2); + + const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); + const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); + + ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; // 2*(2*ncpsg) @@ -2164,7 +2242,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_src3, 4); ggml_metal_encoder_set_buffer (enc, bid_src4, 5); ggml_metal_encoder_set_buffer (enc, bid_pad, 6); - ggml_metal_encoder_set_buffer (enc, bid_dst, 7); + ggml_metal_encoder_set_buffer (enc, bid_blk, 7); + ggml_metal_encoder_set_buffer (enc, bid_dst, 8); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2172,14 +2251,16 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! + const int nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + const bool has_kvpad = ne11 % ncpsg != 0; if (has_kvpad) { @@ -2217,11 +2298,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); - ggml_metal_op_concurrency_reset(ctx); + need_sync = true; } else { assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); } + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 6a6d8a79..d4cb9446 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -40,6 +40,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index e53f37b2..7afc881f 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -194,6 +194,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ case GGML_OP_FLASH_ATTN_EXT: { res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_blk(tensor); res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); } break; default: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c52c6b48..45d91def 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4351,7 +4351,7 @@ kernel void kernel_leaky_relu_f32_4( constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; -constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]]; +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; // pad the last chunk of C elements of k and v into a an extra pad buffer kernel void kernel_flash_attn_ext_pad( @@ -4419,6 +4419,65 @@ kernel void kernel_flash_attn_ext_pad( } } +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; @@ -4473,6 +4532,7 @@ void kernel_flash_attn_ext_impl( device const char * mask, device const char * sinks, device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -4538,6 +4598,13 @@ void kernel_flash_attn_ext_impl( pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); } + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + { q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; @@ -4597,11 +4664,14 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C) { - int ic = ic0; + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; + if (ic >= args.ne11) { + break; + } // the last partial chunk uses the pad buffer as source - if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) { + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { k = pad; v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; @@ -4640,6 +4710,14 @@ void kernel_flash_attn_ext_impl( // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } + + continue; + } + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; @@ -4652,6 +4730,9 @@ void kernel_flash_attn_ext_impl( pm2[jj] += NW; } +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + threadgroup_barrier(mem_flags::mem_threadgroup); // used to detect blocks full of -INF @@ -4670,6 +4751,7 @@ void kernel_flash_attn_ext_impl( continue; } +#endif } // Q*K^T @@ -4687,26 +4769,24 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // TODO: not good to unroll for large contexts - not sure why? + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - if (DK8 % 16 != 0) { + if (DK % 16 != 0) { k8x8_t mk; q8x8_t mq; FOR_UNROLL (short i = 0; i < DK8; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, pk, NS10, 0, true); - simdgroup_load(mq, pq, DK); + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - - pk += 8; - pq += 8; } } else { k8x8_t mk[2]; @@ -4715,26 +4795,22 @@ void kernel_flash_attn_ext_impl( FOR_UNROLL (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mq[0], pq + 0*8, DK); - simdgroup_load(mq[1], pq + 1*8, DK); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - - pk += 16; - pq += 16; } } simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10 - DK8); - pq += 8*(NSG*0 - DK8); + pk += 8*(NSG*NS10); ps += 8*(NSG); } } else { @@ -4868,27 +4944,50 @@ void kernel_flash_attn_ext_impl( } { - auto sst = ss; - device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, sst, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - v8x8_t mv; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv, pv, NS20, 0, false); - simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - pv += 8*NSG; + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; - pv += 8*(NS20 - NO*NSG); - sst += 8; + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } } } @@ -5002,7 +5101,7 @@ void kernel_flash_attn_ext_impl( device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float scale = 1.0f/S[jj]; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; if (DV4 % NW == 0) { FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { @@ -5047,8 +5146,8 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = 8, // queries per threadgroup - short C = 64> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -5057,13 +5156,14 @@ kernel void kernel_flash_attn_ext( device const char * mask, device const char * sinks, device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5210,9 +5310,9 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32, // cache items per threadgroup + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup short NSG> // number of simd groups void kernel_flash_attn_ext_vec_impl( constant ggml_metal_kargs_flash_attn_ext_vec & args, @@ -5327,8 +5427,8 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } @@ -5621,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl( device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { @@ -5659,8 +5759,8 @@ template< short DK, // K head size short DV, // V head size short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, @@ -5799,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce( const float m = simd_max(M); const float ms = exp(M - m); - S = 1.0f/simd_sum(S*ms); + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; const short DV4 = DV/4; From 21e6e72a2fb4a540a002855e568cb21d2b6f08c6 Mon Sep 17 00:00:00 2001 From: ai-fonsi Date: Wed, 8 Oct 2025 20:21:46 +0200 Subject: [PATCH 039/104] Disable CUDA host buffers on integrated GPUs (llama/16308) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 26e72bbc..fb691528 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -231,7 +231,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].integrated = prop.integrated; + info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034) info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; From 7df6766b63a8d38d6f73b46d1d6426f1a0ef2bcc Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Thu, 9 Oct 2025 15:25:11 +0800 Subject: [PATCH 040/104] refactor soft_max, add soft_max_back (llama/16472) * refactor to support soft_max_ext * fix error and support soft_max_back * rm unused functions * fix format issue --------- Co-authored-by: Zhang Jianyu --- ggml/src/ggml-sycl/common.hpp | 86 ++++- ggml/src/ggml-sycl/dpct/helper.hpp | 20 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 25 +- ggml/src/ggml-sycl/softmax.cpp | 491 +++++++++++++++++++---------- ggml/src/ggml-sycl/softmax.hpp | 4 + 5 files changed, 437 insertions(+), 189 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 4e7449d0..d66d7ade 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -197,6 +197,7 @@ struct sycl_device_info { int cc; // compute capability // int nsm; // number of streaming multiprocessors // size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) bool vmm; // virtual memory support size_t total_vram; //sycl_hw_info hw_info; \\ device id and aarch, currently not used @@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) { #pragma unroll for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { - /* - DPCT1096:98: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value - of the right-most dimension is a multiple of "32". - */ x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask); } return x; @@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { return a; } +template +static __dpct_inline__ int warp_reduce_sum(int x) { + return sycl::reduce_over_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>()); +} + +template +static __dpct_inline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width); + } + return x; +} + +template +static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + a.x() += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset, + width); + a.y() += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset, + width); + } + return a; +} + +template +static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + a = a + dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset, + width); + } + return a; +} + +static constexpr int ggml_sycl_get_physical_warp_size() { + // todo: for old iGPU + dGPU case, need to be changed. + return WARP_SIZE; +} + +template +static __dpct_inline__ float warp_reduce_max(float x) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = sycl::fmax(x, dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width)); + } + return x; +} + static __dpct_inline__ float warp_reduce_max(float x, const sycl::nd_item<3>& item_ct1) { #pragma unroll for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { - /* - DPCT1096:97: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value - of the right-most dimension is a multiple of "32". - */ x = sycl::fmax(x, dpct::permute_sub_group_by_xor( item_ct1.get_sub_group(), x, mask)); } @@ -558,4 +602,18 @@ struct scope_op_debug_print { std::string_view func_suffix; }; +static __dpct_inline__ float get_alibi_slope(const float max_bias, + const uint32_t h, + const uint32_t n_head_log2, + const float m0, + const float m1) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + return dpct::pow(base, exph); +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index d538965b..f93cfa70 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -277,6 +277,26 @@ namespace dpct } // namespace detail + // COPY from DPCT head files + /// dim3 is used to store 3 component dimensions. + class dim3 { + public: + unsigned x, y, z; + + constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1) + : x(x), y(y), z(z) {} + + dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {} + + operator sycl::range<3>() const { return sycl::range<3>(z, y, x); } + }; // namespace dim3 + + inline dim3 operator*(const dim3 &a, const dim3 &b) { + return dim3{a.x * b.x, a.y * b.y, a.z * b.z}; + } + // COPY from DPCT head files + + /// Pitched 2D/3D memory data. class pitched_data { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 4ac919ea..e4cc3c8e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -87,6 +87,7 @@ static ggml_sycl_device_info ggml_sycl_init() { 100 * prop.get_major_version() + 10 * prop.get_minor_version(); info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + info.devices[i].smpbo = prop.get_local_mem_size(); } for (int id = 0; id < info.device_count; ++id) { @@ -3741,6 +3742,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_sycl_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_sycl_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); break; @@ -3778,6 +3782,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg return true; } catch (sycl::exception & e) { std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::cerr << "Error OP "<op)<< std::endl; std::exit(1); } @@ -4386,19 +4391,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; - case GGML_OP_SOFT_MAX: - // TODO: support batching - if (op->src[0]->ne[3] != 1) { - return false; - } - // TODO: support attention sinks [TAG_ATTN_SINKS] - if (op->src[2]) { - return false; - } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_DIAG_MASK_INF: + return true; + case GGML_OP_SOFT_MAX: + return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 52fcf4b3..83b7c71b 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,37 +1,94 @@ #include "softmax.hpp" +#include +#include +#include -template -static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { - const int ncols = ncols_template == 0 ? ncols_par : ncols_template; - const int tid = item_ct1.get_local_id(2); - const int rowx = item_ct1.get_group(2); - const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension +template static __dpct_inline__ float t2f32(T val) { + return (float) val; +} - const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; +template <> float __dpct_inline__ t2f32(sycl::half val) { + return sycl::vec(val) + .convert()[0]; +} - const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static void soft_max_f32(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params p, + uint8_t * dpct_local) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; + const int block_size = block_size_template == 0 + ? item_ct1.get_local_range(2) + : block_size_template; const int nthreads = block_size; const int nwarps = nthreads / WARP_SIZE; size_t nreduce = nwarps / WARP_SIZE; - float slope = 1.0f; - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = rowx/nrows_y; // head index + const int tid = item_ct1.get_local_id(2); - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int64_t i03 = item_ct1.get_group(0); + const int64_t i02 = item_ct1.get_group(1); + const int64_t i01 = item_ct1.get_group(2); - slope = sycl::pow(base, float(exp)); - } + //TODO: noncontigous inputs/outputs + const int rowx = item_ct1.get_group(2) + + item_ct1.get_group(1) * item_ct1.get_group_range(2) + + item_ct1.get_group(0) * item_ct1.get_group_range(2) * + item_ct1.get_group_range(1); - float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; - float max_val = -INFINITY; + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; + x += int64_t(rowx)*ncols; + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); + dst += int64_t(rowx)*ncols; + + const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); + + float * buf_iw = (float *) dpct_local; + + // shared memory buffer to cache values between iterations: + float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst; + float max_val = sinks ? sinks[i02] : -INFINITY; +#pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; @@ -39,42 +96,35 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int break; } - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; - max_val = sycl::max(max_val, val); + max_val = sycl::max(max_val, val); } - // find the max value in the block - max_val = warp_reduce_max(max_val, item_ct1); + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { if (warp_id == 0) { - buf[lane_id] = -INFINITY; - for (size_t i = 1; i < nreduce; i += 1) { - buf[lane_id + i * WARP_SIZE] = -INFINITY; - } + buf_iw[lane_id] = -INFINITY; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (lane_id == 0) { - buf[warp_id] = max_val; + buf_iw[warp_id] = max_val; } - item_ct1.barrier(sycl::access::fence_space::local_space); - max_val = buf[lane_id]; - for (size_t i = 1; i < nreduce; i += 1) { - max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); - } - max_val = warp_reduce_max(max_val, item_ct1); - } + item_ct1.barrier(); + + max_val = buf_iw[lane_id]; + max_val = warp_reduce_max(max_val); + } + float tmp = 0.0f; // partial sum - float tmp = 0.f; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; - if (ncols_template == 0 && col >= ncols) { + + if (ncols_template == 0 && col >= ncols) { break; } @@ -82,32 +132,33 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int tmp += val; vals[col] = val; } - // find the sum of exps in the block - tmp = warp_reduce_sum(tmp, item_ct1); + tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (warp_id == 0) { - buf[lane_id] = 0.f; + buf_iw[lane_id] = 0.0f; for (size_t i = 1; i < nreduce; i += 1) { - buf[lane_id + i * WARP_SIZE] = 0.f; + buf_iw[lane_id + i * WARP_SIZE] = 0.f; } } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (lane_id == 0) { - buf[warp_id] = tmp; + buf_iw[warp_id] = tmp; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); - tmp = buf[lane_id]; + tmp = buf_iw[lane_id]; for (size_t i = 1; i < nreduce; i += 1) { - tmp += buf[lane_id + i * WARP_SIZE]; + tmp += buf_iw[lane_id + i * WARP_SIZE]; } - tmp = warp_reduce_sum(tmp, item_ct1); + tmp = warp_reduce_sum(tmp); } - - const float inv_sum = 1.f / tmp; + if (sinks) { + tmp += sycl::native::exp(sinks[i02] - max_val); + } + const float inv_sum = 1.0f / tmp; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -117,145 +168,259 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int return; } - const int idst = rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +static void soft_max_back_f32(const float *grad, const float *dstf, float *dst, + const int ncols, const float scale) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int tid = item_ct1.get_local_id(2); + const int rowx = item_ct1.get_group(2); + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; } } -template -static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, - const size_t n_local_scratch, queue_ptr stream) { +template +static void launch_soft_max_kernels(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & p, + dpct::queue_ptr stream, + dpct::dim3 block_dims, + dpct::dim3 block_nums, + size_t nbytes_shared) +{ + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + if (p.ncols == ncols) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + WARP_SIZE)]] { + soft_max_f32( + x, mask, sinks, dst, p, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); + }); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor local_buf_acc(n_local_scratch, cgh); + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared), cgh); cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - soft_max_f32(x, mask, dst, ncols_par, - nrows_y, scale, max_bias, m0, - m1, n_head_log2, item_ct1, - get_pointer(local_buf_acc)); - }); + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + soft_max_f32( + x, mask, sinks, dst, p, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); }); } -template -static void soft_max_f32_sycl(const float * x, const T * mask, - float * dst, const int ncols_x, const int nrows_x, - const int nrows_y, const float scale, const float max_bias, - queue_ptr stream, int device) { +template +static void soft_max_f32_sycl(const float *x, const T *mask, + const float *sinks, float *dst, + const soft_max_params ¶ms, + dpct::queue_ptr stream, int device) { int nth = WARP_SIZE; int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; + const int64_t ncols_x = params.ncols; + while (nth < ncols_x && nth < max_block_size) nth *= 2; if (nth>max_block_size) nth = max_block_size; - const sycl::range<3> block_dims(1, 1, nth); - const sycl::range<3> block_nums(1, 1, nrows_x); - const size_t n_val_tmp = nth / WARP_SIZE; - const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp); + const dpct::dim3 block_dims(nth, 1, 1); + const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03); + const size_t nbytes_shared = + (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const int id = get_current_device_id(); + const size_t smpbo = ggml_sycl_info().devices[id].smpbo; - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - const size_t local_mem_size = stream->get_device().get_info(); - if (n_local_scratch*sizeof(float) < local_mem_size) { - if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - return; - } - switch (ncols_x) { - case 32: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 64: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 128: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 256: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 512: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 1024: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 2048: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 4096: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - default: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - } + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>( + x, mask, sinks, dst, params, stream, block_dims, block_nums, + nbytes_shared); } else { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, WARP_SIZE, stream); + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared_low), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + soft_max_f32( + x, mask, sinks, dst, params, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); + }); } } +static void soft_max_back_f32_sycl(const float * grad, + const float * dstf, + float * dst, + const int ncols, + const int nrows, + const float scale, + dpct::queue_ptr stream) { + const dpct::dim3 block_dims(WARP_SIZE, 1, 1); + const dpct::dim3 block_nums(nrows, 1, 1); + + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + soft_max_back_f32(grad, dstf, dst, ncols, scale); + GGML_UNUSED(item_ct1); + }); +} + void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional + // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows_x = ggml_nrows(dst->src[0]); - const int64_t nrows_y = dst->src[0]->ne[1]; + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; - float scale = 1.0f; + const int64_t ne00 = src0->ne[0]; + + float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - ggml_sycl_set_device(ctx.device); - dpct::queue_ptr main_stream = ctx.stream(); + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; - if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { - const sycl::half * src1_dd = static_cast(dst->src[1]->data); - soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, - main_stream, ctx.device); - } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { - const float * src1_dd = static_cast(dst->src[1]->data); - soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + + if (use_f16) { + soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d, + (const float *)src2_d, dst_d, params, stream, + ctx.device); } else { - /* mask unavailable */ - soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d, + dst_d, params, stream, ctx.device); } } + +void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp index 2cf8582e..23f1e5a9 100644 --- a/ggml/src/ggml-sycl/softmax.hpp +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -15,6 +15,10 @@ #include "common.hpp" +#define SYCL_SOFT_MAX_BLOCK_SIZE 1024 + void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); +void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + #endif // GGML_SYCL_SOFTMAX_HPP From c8b2c56fd27eb9b66322e5700c84664c01cb6a26 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Thu, 9 Oct 2025 09:29:17 +0200 Subject: [PATCH 041/104] kleidiai: kernel interface refactoring (llama/16460) --- ggml/src/ggml-cpu/kleidiai/kernels.cpp | 305 ++++++++++++++++-------- ggml/src/ggml-cpu/kleidiai/kernels.h | 76 +++--- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 124 ++++------ 3 files changed, 292 insertions(+), 213 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 7ba65912..3eaa5e3f 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -29,6 +29,108 @@ #define NELEMS(x) sizeof(x) / sizeof(*x) +template +static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { + return Fn(a, b, c); +} + +template +static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) { + return Fn(a, b); +} + +template +static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl, + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { + Fn(m, n, k, bl, lhs, rhs, static_cast(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max); +} + +template +static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { + Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max); +} + +template +static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + return Fn(m, k, bl, mr, kr, sr); +} + +template +static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { + return Fn(m, k, mr, kr, sr); +} + +template +static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + return Fn(m_idx, k, bl, mr, kr, sr); +} + +template +static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { + return Fn(m_idx, k, mr, kr, sr); +} + +template +static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast(lhs), lhs_stride, lhs_packed); +} + +template +static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); +} + +template +static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); +} + +template +static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) { + return Fn(n, k, nr, kr, bl); +} + +template +static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { + return Fn(n, k); +} + +template +static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) { + return Fn(k, nr, kr, bl); +} + +template +static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { + return Fn(k); +} + +template +static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, + size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/, + void* rhs_packed, size_t extra_bytes, const void* params) { + Fn(num_groups, n, k, nr, kr, sr, bl, + static_cast(rhs), + static_cast(bias), + rhs_packed, extra_bytes, + static_cast(params)); +} + +template +static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/, + size_t rhs_stride, const void* rhs, const void* bias, const void* scale, + void* rhs_packed, size_t extra_bytes, const void* params) { + Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params); +} + static const size_t INT4_PER_BYTE = 2; static const size_t INT4_BITS = 4; static const int Q4_0_ZERO_POINT = 8; @@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, + /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* SME GEMV */ /* .kern_info = */ { @@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn2, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2, + /* .run_kernel_ex = */ &kernel_run_fn10, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset_ex = */ &lhs_offs_fn5, + /* .packed_size_ex = */ &lhs_ps_fn5, + /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* SME GEMV */ /* .kern_info = */ { @@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset_ex = */ nullptr, + /* .get_rhs_packed_offset_ex = */ nullptr, + /* .run_kernel_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset_ex = */ &lhs_offs_fn5, + /* .packed_size_ex = */ &lhs_ps_fn5, + /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .packed_stride = */ NULL, - /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .to_float = */ NULL, + /* .packed_stride = */ nullptr, + /* .to_float = */ nullptr, + /* .packed_size_ex = */ &rhs_ps_fn2, + /* .packed_stride_ex = */ &rhs_stride_fn1, + /* .pack_func_ex = */ &rhs_pack_fn13, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { @@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { @@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { @@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { @@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c ggml_kleidiai_kernels * kernel = nullptr; if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && @@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c break; } } +#endif } return kernel; @@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { kernels = &gemm_gemv_kernels[i]; break; } } +#endif return kernels; } diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 2ad6ad6f..a84795a6 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -4,8 +4,6 @@ #pragma once -#include -#include #include "ggml.h" enum cpu_feature { @@ -15,6 +13,7 @@ enum cpu_feature { CPU_FEATURE_SVE = 4, CPU_FEATURE_SME = 8 }; + inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { lhs = static_cast(lhs | rhs); return lhs; @@ -30,63 +29,52 @@ struct kernel_info { size_t (*get_nr)(void); size_t (*get_kr)(void); size_t (*get_sr)(void); - std::variant< - std::function, - std::function - > get_lhs_offset; - std::variant< - std::function, - std::function - > get_rhs_packed_offset; + size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); size_t (*get_dst_size)(size_t m, size_t n); - std::variant< - std::function, - std::function - > run_kernel; + + size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl); + + size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl); + + void (*run_kernel_ex)( + size_t m, size_t n, size_t k, size_t bl, + const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max); }; struct lhs_packing_info { size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - std::variant< - std::function, - std::function - > get_packed_offset; - std::variant< - std::function, - std::function - > packed_size; - std::variant< - std::function, - std::function - > pack_func; + + size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + + size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + + void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed); }; struct rhs_packing_info { - std::variant< - std::function, - std::function - > packed_size; size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl); - std::variant< - std::function, - std::function - > pack_func; - void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride, - size_t kr, size_t bl, size_t num_bytes_multiplier); + + void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, + size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl, + size_t num_bytes_multiplier); + + size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); + + size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl); + + void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, + size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params); }; struct ggml_kleidiai_kernels { - kernel_info gemm; + kernel_info gemm; lhs_packing_info gemm_lhs_info; - kernel_info gemv; + kernel_info gemv; lhs_packing_info gemv_lhs_info; rhs_packing_info rhs_info; diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 44691e5d..8b3df7d7 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #if defined(__linux__) #include #include @@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } -template -constexpr bool variant_any_invocable_impl(std::index_sequence) { - using V = std::remove_reference_t; - return (std::is_invocable_r_v< - Ret, - std::variant_alternative_t, - Args...> || ...); -} - -template -constexpr bool variant_any_invocable_v = - variant_any_invocable_impl( - std::make_index_sequence< - std::variant_size_v>>{}); - -template -static inline Ret variant_call(Variant && var, Args&&... args) { - static_assert(variant_any_invocable_v, Ret, Args...>, - "No alternative in Variant is invocable with the provided arguments and return type."); - - return std::visit( - [&](auto && f) -> Ret { - using F = std::decay_t; - if constexpr (std::is_invocable_r_v) { - return std::invoke(std::forward(f), std::forward(args)...); - } else { - GGML_ABORT("Invalid function type in variant_call"); - GGML_UNREACHABLE(); - } - }, - std::forward(var) - ); -} - namespace ggml::cpu::kleidiai { static size_t round_down(size_t x, size_t y) { @@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { return false; } ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } bool is_gemv = op->src[1]->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; @@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t sr = kernel->get_sr(); if (kernels->rhs_type == GGML_TYPE_Q4_0) { - size = variant_call(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); + if (!lhs_info->packed_size_ex) return false; + size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr); } else if (kernels->rhs_type == GGML_TYPE_F16) { + if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false; const int64_t lhs_batch_size0 = op->src[1]->ne[2]; const int64_t rhs_batch_size0 = op->src[0]->ne[2]; const int64_t r = lhs_batch_size0 / rhs_batch_size0; - size = variant_call(lhs_info->packed_size, m * r, k, mr, kr, sr) + - variant_call(kernels->rhs_info.packed_size, n, k) + + size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) + + kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) + k * n * sizeof(float) + n * sizeof(float); } else { - GGML_ASSERT(false); + return false; } return true; @@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_TENSOR_BINARY_OP_LOCALS ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } const bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); + if (!kernels->rhs_info.pack_func_ex || + !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) { + return false; + } const int nth = params->nth; const int ith = params->ith; @@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t kr = (int64_t) kernel->get_kr(); const int64_t sr = (int64_t) kernel->get_sr(); - const size_t lhs_packed_size = variant_call(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, (size_t)n, (size_t)k); - const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float); - const size_t bias_size = (size_t)n * sizeof(float); + const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr); + const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0); + const size_t kxn_size = k * n * sizeof(float); + const size_t bias_size = n * sizeof(float); const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; GGML_ASSERT(wsize_required <= params->wsize); @@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; // Base packed offset (aligned) and per-row stride in bytes - const size_t base_packed_off = variant_call( - lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t next_block_off = variant_call( - lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr); + const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr); const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr; int64_t remaining = m_count; @@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; void * dst_ptr = lhs_packed + dst_off; - variant_call(lhs_info->pack_func, - (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr, - /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr); + lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); cur += take; remaining -= take; @@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { reinterpret_cast(rhs_batch_base), rhs_stride); - variant_call(kernels->rhs_info.pack_func, - /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr, - /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)), - rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr); + kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float), + rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); } ggml_barrier(params->threadpool); @@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0; // LHS packed base at row 0 (consistent with packing above) - const size_t lhs_packed_offset0 = variant_call( - lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k); - const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); + const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0); + const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); const void * lhs_ptr = lhs_packed + lhs_packed_offset0; const void * rhs_ptr = rhs_packed + rhs_packed_offset; float * dst_ptr = reinterpret_cast(dst_batch_base + dst_offset); - variant_call(kernel->run_kernel, - (size_t)m, (size_t)n_to_process, (size_t)k, - lhs_ptr, rhs_ptr, - dst_ptr, dst_stride, sizeof(float), - -FLT_MAX, FLT_MAX); + kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } } @@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_TENSOR_BINARY_OP_LOCALS ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); + if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || + !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { + return false; + } const int ith = params->ith; const int nth_raw = params->nth; @@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits { // Transform LHS const size_t src_stride = src1->nb[1]; const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); + const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr); void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer + lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); } ggml_barrier(params->threadpool); // Perform the operation const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); + const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0); const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); if (n_to_process > 0) { - variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, + kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } @@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); - GGML_ASSERT(ctx.kernels); + if (!ctx.kernels) { + return false; + } const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { rhs_packing_info * rhs_info = &ctx.kernels->rhs_info; kernel_info * kernel = &ctx.kernels->gemm; + if (!rhs_info->to_float || !kernel->get_nr) { + return false; + } const int64_t nc = ne00; const int64_t nr = ggml_nelements(src1); @@ -480,7 +458,7 @@ public: struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); + ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms); return 0; GGML_UNUSED(data_size); @@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_ const size_t nr = ctx.kernels->gemm.get_nr(); const size_t kr = ctx.kernels->gemm.get_kr(); - return variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); + return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0); GGML_UNUSED(buft); } From b9eac9419c03120b2dc319bbd16357443f8a592c Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 9 Oct 2025 15:50:25 +0800 Subject: [PATCH 042/104] CANN: Improve ACL graph matching (llama/16166) * CANN: improve ACL graph matching Record `ne` and `nb` information for src tensors and include them in the graph matching check. This enhances the robustness of ACL graph matching by preventing incorrect matches when src tensors share the same data address but differ in shape or stride. * CANN: add op_params match --- ggml/src/ggml-cann/common.h | 9 +++++- ggml/src/ggml-cann/ggml-cann.cpp | 48 ++++++++++++++++++++++++-------- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index b707b843..debbcadc 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -341,11 +341,18 @@ private: #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { + // dst tensor void * node_address; - ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; + + // src tensor void * src_address[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + + // op + ggml_op node_op; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b51b554e..ad1adba6 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties( std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); for (int src = 0; src < GGML_MAX_SRC; ++src) { - prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr; + if (node->src[src]) { + prop.src_address[src] = node->src[src]->data; + std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); + std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); + } else { + prop.src_address[src] = nullptr; + std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); + std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); + } } memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); @@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties( * @param graph_node_properties The stored properties of a CANN graph node. * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. */ -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { +static bool ggml_graph_node_has_matching_properties( + ggml_tensor * node, + ggml_graph_node_properties * graph_node_properties) { if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_VIEW) { + node->op != GGML_OP_VIEW) { return false; } + if (node->op != graph_node_properties->node_op) { return false; } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != graph_node_properties->ne[i]) { return false; @@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_VIEW - ) { - return false; + if (node->src[i]) { + if (node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_VIEW) { + return false; + } + + for (int d = 0; d < GGML_MAX_DIMS; d++) { + if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) { + return false; + } + if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) { + return false; + } + } + } else { + if (graph_node_properties->src_address[i] != nullptr) { + return false; + } } } - if (node->op == GGML_OP_SCALE && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; + + if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { + return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } return true; } From d83fef35dfc48a2a5d35f46b4a999b768e37c32e Mon Sep 17 00:00:00 2001 From: duduta Date: Thu, 9 Oct 2025 22:11:15 +0300 Subject: [PATCH 043/104] cpu : optimize the ggml NORM operation (llama/15953) * ggml-cpu: optimize norm operation to use intrinsics or Accelerate rename function add endif macro comment Co-authored-by: Georgi Gerganov Co-authored-by: Aaron Teo * implement s390x SIMD suggested by @taronaeo * add TODO comment * tidy up spaces --------- Co-authored-by: Georgi Gerganov Co-authored-by: Aaron Teo --- ggml/src/ggml-cpu/ops.cpp | 24 ++++++-------- ggml/src/ggml-cpu/vec.cpp | 66 +++++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/vec.h | 1 + 3 files changed, 77 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8e1a2de1..1c43865f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(eps >= 0.0f); - // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, x); float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + float variance = 0; - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } +#ifdef GGML_USE_ACCELERATE + mean = -mean; + vDSP_vsadd(x, 1, &mean, y, 1, ne00); + vDSP_measqv(y, 1, &variance, ne00); +#else + variance = ggml_vec_cvar_f32(ne00, y, x, mean); +#endif //GGML_USE_ACCELERATE - float variance = sum2/ne00; const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); } } diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 437192d5..b8e37052 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * } } +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) { + int i = 0; + ggml_float sum = 0; +// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE +// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344 +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(mean)); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val)); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(mean)); + _mm256_storeu_ps(y + i, val); + val = _mm256_mul_ps(val,val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(mean)); + _mm_storeu_ps(y + i, val); + val = _mm_mul_ps(val, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif // __AVX__ || __AVX2__ || __AVX512F__ + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(mean)); + vst1q_f32(y + i, val); + val = vmulq_f32(val, val); + sum += (ggml_float)vaddvq_f32(val); + } +#elif defined(__VXE__) || defined(__VXE2__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean)); + vec_xst(val, 0, y + i); + val = vec_mul(val, val); + sum += (ggml_float)vec_hsum_f32x4(val); + } +#endif + for (; i < n; ++i) { + float val = x[i] - mean; + val *= val; + sum += (ggml_float)val; + y[i] = val; + } + return sum/n; +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index f95ca94e..2751359c 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_silu_f32(const int n, float * y, const float * x); +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean ) ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); From d8f1aa4e1d7ca6fd46b44683289a5b850b4bdc6a Mon Sep 17 00:00:00 2001 From: Prajwal B Mehendarkar Date: Fri, 10 Oct 2025 13:45:46 +0530 Subject: [PATCH 044/104] cmake : Dont define XOPENSOURCE on AIX (llama/16481) --- ggml/src/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c8f3d859..892c2331 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -145,6 +145,9 @@ endif() # which was introduced in POSIX.1-2008, forcing us to go higher if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") add_compile_definitions(_XOPEN_SOURCE=700) +elseif (CMAKE_SYSTEM_NAME MATCHES "AIX") + # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default, + # in order to define _SC_PHYS_PAGES. else() add_compile_definitions(_XOPEN_SOURCE=600) endif() From 1cc342427b77e3ee3a6693d4e86da66111f326fe Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Sat, 11 Oct 2025 04:02:26 -0700 Subject: [PATCH 045/104] cuda : avoid initializing unused devices (llama/16510) --- ggml/src/ggml-cuda/ggml-cuda.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fb691528..856e9de2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3867,7 +3867,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { dev_ctx->device = i; dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); - ggml_cuda_set_device(i); cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; From d201705e71d21cbcd001e905a365d6f25fe3bbb2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 Oct 2025 16:54:10 +0300 Subject: [PATCH 046/104] metal : fix mul-mm condition + fix mul-mv permuted kernels (llama/16494) --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 66 +++++++++++++++----------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 1137e210..5f937044 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { !ggml_is_transposed(op->src[1]) && // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - props_dev->has_simdgroup_mm && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 45d91def..ddc28504 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7517,6 +7516,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7524,24 +7526,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7679,7 +7685,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7731,6 +7736,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7738,20 +7746,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; From ed6a3063ec0df6af61327f61d67449418ef42b2b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 Oct 2025 08:36:34 +0300 Subject: [PATCH 047/104] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 5e09de49..b84ddf48 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -72632094336524a9c809e129e8b1c52154543a5a +fcc2a5c0cfd81ee0517ee42f1acdc371ec92d598 From ff4c1a5a53887829d1eed250f554c021bfcd170b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 Oct 2025 08:37:14 +0300 Subject: [PATCH 048/104] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 62 +++ examples/talk-llama/llama-arch.h | 15 + examples/talk-llama/llama-chat.cpp | 2 +- examples/talk-llama/llama-context.cpp | 6 + examples/talk-llama/llama-graph.cpp | 17 + examples/talk-llama/llama-graph.h | 8 + examples/talk-llama/llama-hparams.cpp | 6 +- examples/talk-llama/llama-hparams.h | 14 +- examples/talk-llama/llama-kv-cache-iswa.cpp | 4 +- examples/talk-llama/llama-kv-cache.cpp | 7 +- examples/talk-llama/llama-memory-hybrid.cpp | 20 +- .../talk-llama/llama-memory-recurrent.cpp | 14 +- examples/talk-llama/llama-model-loader.cpp | 2 + examples/talk-llama/llama-model.cpp | 379 ++++++++++++++++-- examples/talk-llama/llama-model.h | 13 + examples/talk-llama/llama-sampling.cpp | 5 + examples/talk-llama/llama-vocab.cpp | 6 + examples/talk-llama/llama-vocab.h | 81 ++-- examples/talk-llama/llama.h | 8 + 19 files changed, 565 insertions(+), 104 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 4e8d54c4..869e4dcc 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -93,12 +93,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_LLADA_MOE, "llada-moe" }, { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_GROVEMOE, "grovemoe" }, + { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -217,6 +219,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -256,6 +263,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, + { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" }, + { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" }, + { LLM_KV_XIELU_BETA, "xielu.beta" }, + { LLM_KV_XIELU_EPS, "xielu.eps" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -1064,6 +1076,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2098,6 +2112,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, } }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_SMALLTHINKER, { @@ -2119,6 +2159,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } }, }, + { + LLM_ARCH_APERTUS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_DREAM, { @@ -2229,6 +2288,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2468,6 +2529,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: return true; default: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index b5c6f3d7..c3ae7165 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -97,12 +97,14 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, LLM_ARCH_LLADA_MOE, LLM_ARCH_SEED_OSS, LLM_ARCH_GROVEMOE, + LLM_ARCH_APERTUS, LLM_ARCH_UNKNOWN, }; @@ -260,10 +262,21 @@ enum llm_kv { LLM_KV_SHORTCONV_L_CACHE, + LLM_KV_XIELU_ALPHA_N, + LLM_KV_XIELU_ALPHA_P, + LLM_KV_XIELU_BETA, + LLM_KV_XIELU_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, }; enum llm_tensor { @@ -271,6 +284,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 66e6c6a3..956c4e08 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -590,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index d8a8b5e6..e7526e7d 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model( return nullptr; } + if (params.pooling_type != model->hparams.pooling_type) { + //user-specified pooling-type is different from the model default + LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, + model->hparams.pooling_type, params.pooling_type); + } + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 90cd885a..a24853c6 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const { + if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + return; + } + ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; + GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); + + cur = ggml_mul_mat(ctx0, dense_2, cur); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 34b984af..dc84b794 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index c04ac58f..db65d69e 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const { } bool llama_hparams::is_recurrent(uint32_t il) const { - return recurrent_layer_arr[il]; + if (il < n_layer) { + return recurrent_layer_arr[il]; + } + + GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); } uint32_t llama_hparams::n_pos_per_embd() const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0fe4b569..4e7f73ec 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -42,7 +42,7 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_embd_features = 0; uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head @@ -169,6 +169,18 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + + // xIELU + std::array xielu_alpha_n; + std::array xielu_alpha_p; + std::array xielu_beta; + std::array xielu_eps; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 827302e6..facba1d0 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const { } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_write(io, seq_id, flags); } @@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id } void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_read(io, seq_id, flags); } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 816f2d5d..736693e1 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index abf65248..dfb8439e 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -175,17 +177,17 @@ std::map llama_memory_hybrid::memory_breakdo } void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - GGML_UNUSED(flags); - - mem_attn->state_write(io, seq_id); - mem_recr->state_write(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_write(io, seq_id, flags); + } + mem_recr->state_write(io, seq_id, flags); } void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - GGML_UNUSED(flags); - - mem_attn->state_read(io, seq_id); - mem_recr->state_read(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_read(io, seq_id, flags); + } + mem_recr->state_read(io, seq_id, flags); } llama_kv_cache * llama_memory_hybrid::get_mem_attn() const { diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 44645fcd..d67f5a5f 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (tail_id >= 0) { const auto & cell = cells[tail_id]; // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } // invalidate tails which will be cleared @@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } else { // seq_id is negative, then the range should include everything or nothing if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n"); return false; } } @@ -379,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -856,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (cell_count == 0) { + return true; + } + llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 8182a9ad..aa3a65f8 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -465,6 +465,8 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ffd9286e..36d495d6 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -310,7 +311,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -331,11 +332,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -512,9 +515,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -1084,7 +1091,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; default: type = LLM_TYPE_UNKNOWN; - } + } + + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_GPT2: { @@ -1207,12 +1218,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); switch (hparams.n_layer) { case 24: type = LLM_TYPE_0_3B; break; @@ -1985,14 +2005,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + type = LLM_TYPE_8B_A1B; + } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -2029,6 +2064,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_APERTUS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2062,7 +2110,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3392,17 +3440,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO2: { + // mamba parameters const uint32_t d_conv = hparams.ssm_d_conv; const uint32_t d_state = hparams.ssm_d_state; const uint32_t num_heads = hparams.ssm_dt_rank; const uint32_t intermediate_size = hparams.ssm_d_inner; - const uint32_t head_dim = intermediate_size / num_heads; - const uint32_t qk_dim = head_dim; - const uint32_t v_dim = head_dim; - const int64_t num_attention_heads = hparams.n_head(); - const int64_t q_num_heads = num_attention_heads; const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k; + const uint32_t v_dim = hparams.n_embd_head_v; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3436,6 +3484,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; const int64_t num_key_value_heads = hparams.n_head_kv(i); const int64_t k_num_heads = num_key_value_heads; const int64_t v_num_heads = num_key_value_heads; @@ -3444,8 +3494,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t v_proj_dim = v_num_heads * v_dim; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); } @@ -3645,6 +3695,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4825,11 +4880,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); } } } @@ -5787,6 +5844,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -5798,11 +5856,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -5907,6 +5977,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); } } break; + case LLM_ARCH_APERTUS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6241,7 +6353,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } @@ -7776,6 +7888,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -7785,6 +7899,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_k_norm) { + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -8167,6 +8283,9 @@ struct llm_build_mpt : public llm_graph_context { // Q/K Layernorm if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -11751,6 +11870,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba2_y_add_d", il); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm @@ -14705,6 +14825,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + ggml_build_forward_expand(gf, inpL); auto * inp = build_inp_mem_hybrid(); @@ -14736,7 +14857,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { // add residual cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "block_out", il); + cb(cur, "nemotron_h_block_out", il); // input for next layer inpL = cur; @@ -16192,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) { @@ -17607,6 +17728,7 @@ private: const int64_t n_embd_head_q = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); const int64_t q_offset = 0; @@ -18523,6 +18645,8 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -18537,7 +18661,16 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, "model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); @@ -18552,23 +18685,32 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } ggml_tensor * build_attn_block(ggml_tensor * cur, @@ -19088,6 +19230,141 @@ struct llm_build_grovemoe : public llm_graph_context { } }; +struct llm_build_apertus : public llm_graph_context { + llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + cb(Vcur, "Vcur_pos", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network with xIELU activation + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Up projection + ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + float alpha_n_val = hparams.xielu_alpha_n[il]; + float alpha_p_val = hparams.xielu_alpha_p[il]; + float beta_val = hparams.xielu_beta[il]; + float eps_val = hparams.xielu_eps[il]; + + // Apply xIELU activation + ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); + cb(activated, "ffn_xielu", il); + + // Down projection + cur = build_lora_mm(model.layers[il].ffn_down, activated); + cb(cur, "ffn_down", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -19603,6 +19880,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, params); } break; @@ -19618,6 +19896,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_APERTUS: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -19625,6 +19907,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + // TODO: move reranking logic here and generalize + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + return llm->res->get_gf(); } @@ -19649,6 +19937,7 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; return result; @@ -19820,10 +20109,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: case LLM_ARCH_GROVEMOE: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -19934,6 +20225,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); +} + bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d73ce969..7f48662f 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -107,6 +107,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air @@ -380,6 +381,12 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // xIELU activation parameters for Apertus + struct ggml_tensor * ffn_act_alpha_n = nullptr; + struct ggml_tensor * ffn_act_alpha_p = nullptr; + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -431,6 +438,12 @@ struct llama_model { std::vector layers; + //Dense linear projections for SentenceTransformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; + llama_model_params params; // gguf metadata diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp index 2186f827..55d2e355 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampling.cpp @@ -2541,8 +2541,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ if (n_non_eog == 0) { cur_p->size = 1; cur_p->data[0].id = ctx->vocab->token_eot(); + if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { + cur_p->data[0].id = ctx->vocab->token_eos(); + } cur_p->data[0].logit = 1.0f; + GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); + return; } diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index da938af0..7fffd171 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_TRILLION: + case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -1961,6 +1962,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "trillion") { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-docling") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING; + clean_spaces = false; } else if ( tokenizer_pre == "bailingmoe" || tokenizer_pre == "llada-moe") { @@ -2166,6 +2171,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end|>" || t.first == "" || t.first == "<|endoftext|>" + || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 0d2f28c3..5e468675 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -8,46 +8,47 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 452d9ec5..a0a660bf 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -296,6 +296,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -543,6 +544,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid (like Jamba, Granite, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); @@ -791,8 +795,12 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +// for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 +// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) +#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( From ea174c62bc8bac36cec499c5be7db75d91cbd129 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 Oct 2025 08:47:48 +0300 Subject: [PATCH 049/104] bench : update [no ci] --- scripts/bench-all-gg.txt | 100 +++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt index d1cdaf9a..cf3d26fb 100644 --- a/scripts/bench-all-gg.txt +++ b/scripts/bench-all-gg.txt @@ -111,61 +111,61 @@ make -j && ./scripts/bench-all.sh 1 1 0 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 0 | 8.63 | 1.09 | 0.27 | 0.01 | b57b9d3a | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.04 | 1.06 | 0.28 | 0.01 | b57b9d3a | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 8.98 | 1.06 | 0.28 | 0.01 | b57b9d3a | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 8.69 | 1.06 | 0.27 | 0.01 | b57b9d3a | -| M2 ULTRA | METAL | base | 1 | 0 | 15.39 | 1.54 | 0.43 | 0.02 | b57b9d3a | -| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.50 | 1.50 | 0.42 | 0.02 | b57b9d3a | -| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.45 | 1.49 | 0.43 | 0.02 | b57b9d3a | -| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 15.62 | 1.51 | 0.42 | 0.02 | b57b9d3a | -| M2 ULTRA | METAL | small | 1 | 0 | 45.99 | 2.99 | 0.90 | 0.05 | b57b9d3a | -| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 50.65 | 2.98 | 0.92 | 0.06 | b57b9d3a | -| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 50.74 | 2.96 | 0.92 | 0.06 | b57b9d3a | -| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 47.16 | 2.83 | 0.89 | 0.06 | b57b9d3a | -| M2 ULTRA | METAL | medium | 1 | 0 | 132.78 | 6.46 | 2.02 | 0.13 | b57b9d3a | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 149.35 | 6.11 | 2.09 | 0.14 | b57b9d3a | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 149.11 | 6.09 | 2.11 | 0.14 | b57b9d3a | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 137.37 | 6.05 | 2.03 | 0.13 | b57b9d3a | -| M2 ULTRA | METAL | medium-dis | 1 | 0 | 121.60 | 0.90 | 0.25 | 0.02 | b57b9d3a | -| M2 ULTRA | METAL | large-v2 | 1 | 0 | 231.19 | 9.40 | 3.10 | 0.22 | b57b9d3a | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 265.90 | 8.98 | 3.11 | 0.25 | b57b9d3a | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 265.18 | 8.92 | 3.13 | 0.25 | b57b9d3a | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 240.23 | 9.06 | 2.98 | 0.23 | b57b9d3a | -| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 210.25 | 0.99 | 0.28 | 0.02 | b57b9d3a | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 211.72 | 1.52 | 0.46 | 0.03 | b57b9d3a | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 242.17 | 1.40 | 0.47 | 0.04 | b57b9d3a | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 219.75 | 1.40 | 0.45 | 0.04 | b57b9d3a | +| M2 ULTRA | METAL | tiny | 1 | 0 | 8.82 | 1.14 | 0.28 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.28 | 1.11 | 0.29 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 9.28 | 1.11 | 0.29 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 8.94 | 1.12 | 0.28 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | base | 1 | 0 | 15.84 | 1.60 | 0.43 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 17.62 | 1.61 | 0.47 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 17.00 | 1.57 | 0.45 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 16.19 | 1.56 | 0.43 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | small | 1 | 0 | 47.72 | 3.12 | 0.92 | 0.06 | 2ad7a695 | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 52.59 | 3.13 | 0.94 | 0.06 | 2ad7a695 | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 52.50 | 3.09 | 0.94 | 0.06 | 2ad7a695 | +| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 48.92 | 2.92 | 0.91 | 0.06 | 2ad7a695 | +| M2 ULTRA | METAL | medium | 1 | 0 | 136.84 | 6.64 | 2.06 | 0.13 | 2ad7a695 | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 152.83 | 6.32 | 2.13 | 0.14 | 2ad7a695 | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 153.27 | 6.30 | 2.14 | 0.14 | 2ad7a695 | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 142.05 | 6.14 | 2.08 | 0.13 | 2ad7a695 | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 123.80 | 0.91 | 0.25 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 238.97 | 9.69 | 3.13 | 0.22 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 273.72 | 9.31 | 3.17 | 0.25 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 273.42 | 9.26 | 3.18 | 0.25 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 247.80 | 9.33 | 3.04 | 0.23 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 213.83 | 1.00 | 0.28 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 215.47 | 1.54 | 0.47 | 0.03 | 2ad7a695 | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 246.32 | 1.44 | 0.47 | 0.04 | 2ad7a695 | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 223.43 | 1.44 | 0.45 | 0.04 | 2ad7a695 | make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 1 | 6.28 | 0.96 | 0.22 | 0.01 | a77d11d9 | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.69 | 0.92 | 0.22 | 0.01 | a77d11d9 | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.67 | 0.91 | 0.22 | 0.01 | a77d11d9 | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.34 | 0.92 | 0.21 | 0.01 | a77d11d9 | -| M2 ULTRA | METAL | base | 1 | 1 | 10.77 | 1.30 | 0.32 | 0.02 | a77d11d9 | -| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.84 | 1.23 | 0.33 | 0.02 | a77d11d9 | -| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 11.95 | 1.24 | 0.33 | 0.02 | a77d11d9 | -| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.14 | 1.23 | 0.32 | 0.02 | a77d11d9 | -| M2 ULTRA | METAL | small | 1 | 1 | 32.12 | 2.43 | 0.65 | 0.04 | a77d11d9 | -| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.95 | 2.42 | 0.68 | 0.04 | a77d11d9 | -| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 37.40 | 2.42 | 0.68 | 0.04 | a77d11d9 | -| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 33.48 | 2.30 | 0.65 | 0.04 | a77d11d9 | -| M2 ULTRA | METAL | medium | 1 | 1 | 89.28 | 5.05 | 1.46 | 0.09 | a77d11d9 | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 105.24 | 4.89 | 1.48 | 0.11 | a77d11d9 | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 105.28 | 4.98 | 1.49 | 0.11 | a77d11d9 | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 93.61 | 4.89 | 1.43 | 0.10 | a77d11d9 | -| M2 ULTRA | METAL | medium-dis | 1 | 1 | 78.44 | 0.81 | 0.20 | 0.01 | a77d11d9 | -| M2 ULTRA | METAL | large-v2 | 1 | 1 | 165.69 | 7.50 | 2.16 | 0.17 | a77d11d9 | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 199.40 | 7.37 | 2.18 | 0.20 | a77d11d9 | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 199.29 | 7.37 | 2.21 | 0.20 | a77d11d9 | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 174.60 | 6.87 | 2.16 | 0.18 | a77d11d9 | -| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 145.80 | 0.90 | 0.22 | 0.02 | a77d11d9 | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 146.98 | 1.31 | 0.34 | 0.03 | a77d11d9 | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 176.77 | 1.19 | 0.35 | 0.03 | a77d11d9 | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 154.73 | 1.20 | 0.33 | 0.03 | a77d11d9 | +| M2 ULTRA | METAL | tiny | 1 | 1 | 6.13 | 0.95 | 0.22 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.56 | 0.91 | 0.22 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.59 | 0.92 | 0.23 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.23 | 0.93 | 0.22 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | base | 1 | 1 | 10.73 | 1.31 | 0.33 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.89 | 1.25 | 0.34 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 11.83 | 1.24 | 0.34 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.03 | 1.25 | 0.32 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | small | 1 | 1 | 32.05 | 2.42 | 0.65 | 0.04 | 2ad7a695 | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.73 | 2.41 | 0.67 | 0.04 | 2ad7a695 | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.77 | 2.41 | 0.68 | 0.04 | 2ad7a695 | +| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 33.33 | 2.28 | 0.65 | 0.04 | 2ad7a695 | +| M2 ULTRA | METAL | medium | 1 | 1 | 88.19 | 5.10 | 1.47 | 0.09 | 2ad7a695 | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 104.23 | 4.90 | 1.48 | 0.10 | 2ad7a695 | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 104.19 | 5.02 | 1.51 | 0.10 | 2ad7a695 | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 92.41 | 4.96 | 1.44 | 0.09 | 2ad7a695 | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 76.97 | 0.79 | 0.20 | 0.01 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 169.61 | 7.48 | 2.14 | 0.17 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 203.04 | 7.35 | 2.18 | 0.20 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 202.91 | 7.32 | 2.20 | 0.20 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 178.30 | 6.86 | 2.12 | 0.18 | 2ad7a695 | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 146.47 | 0.89 | 0.22 | 0.02 | 2ad7a695 | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 147.86 | 1.30 | 0.34 | 0.03 | 2ad7a695 | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 177.75 | 1.17 | 0.35 | 0.03 | 2ad7a695 | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 155.51 | 1.18 | 0.33 | 0.03 | 2ad7a695 | ## M4 Max From a91dd3be72f70dd1b3cb6e252f35fa17b93f596c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 Oct 2025 11:17:59 +0300 Subject: [PATCH 050/104] release : v1.8.1 --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2df1dbaa..91b9d0a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.0) +project("whisper.cpp" VERSION 1.8.1) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index 87525c66..f197c934 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.0](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.0) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.8.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 0cfd6504..ae601157 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.0", + "version": "1.8.1", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From b5fb9b9f58ea65d0e367d1183dd328283aecee66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 11 Oct 2025 20:54:32 +0200 Subject: [PATCH 051/104] CUDA: faster tile FA, add oob checks, more HSs (llama/16492) --- ggml/src/ggml-cuda/CMakeLists.txt | 2 + ggml/src/ggml-cuda/common.cuh | 7 +- ggml/src/ggml-cuda/fattn-common.cuh | 9 +- ggml/src/ggml-cuda/fattn-tile.cu | 771 +---------- ggml/src/ggml-cuda/fattn-tile.cuh | 1213 +++++++++++++++++ ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 2 + ggml/src/ggml-cuda/fattn.cu | 76 +- .../fattn-tile-instance-dkq112-dv112.cu | 5 + .../fattn-tile-instance-dkq128-dv128.cu | 5 + .../fattn-tile-instance-dkq256-dv256.cu | 5 + .../fattn-tile-instance-dkq40-dv40.cu | 5 + .../fattn-tile-instance-dkq576-dv512.cu | 5 + .../fattn-tile-instance-dkq64-dv64.cu | 5 + .../fattn-tile-instance-dkq80-dv80.cu | 5 + .../fattn-tile-instance-dkq96-dv96.cu | 5 + .../template-instances/generate_cu_files.py | 18 +- ggml/src/ggml-hip/CMakeLists.txt | 2 + ggml/src/ggml-musa/CMakeLists.txt | 2 + 18 files changed, 1358 insertions(+), 784 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index bdcefe7b..30247751 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") + file(GLOB SRCS "template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d51abbea..e0abde54 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -245,7 +245,8 @@ static bool fp16_available(const int cc) { } static bool fast_fp16_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc); + return GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, } // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +// Important: do not use this function if dst and src both point at registers. +// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types. +// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions. +// If dst and src point at different address spaces then they are guaranteed to not be aliased. template static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { if constexpr (alignment != 0) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 33d2f0f4..bc0c2523 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -793,8 +793,6 @@ void launch_fattn( GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); - GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -878,7 +876,7 @@ void launch_fattn( // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or // multiple sequences of possibly different lengths. - if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { const int s31 = mask->nb[1] / sizeof(half2); const int s33 = mask->nb[3] / sizeof(half2); @@ -916,8 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); - const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -946,7 +943,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = Q->ne[2]*Q->ne[3]; + blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 68de623d..3a5806d9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -1,756 +1,45 @@ #include "common.cuh" -#include "fattn-common.cuh" #include "fattn-tile.cuh" #include "fattn-wmma-f16.cuh" -// kq_stride == number of KQ rows to process per iteration -// kq_nbatch == number of K columns to load in parallel for KQ calculation - -static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { - if (GGML_CUDA_CC_IS_AMD(cc)) { - if (GGML_CUDA_CC_IS_RDNA(cc)) { - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - if (fast_fp16_available(cc)) { - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - GGML_UNUSED(warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP -#ifdef RDNA - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // RDNA -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else - switch (D) { - case 64: - return 64; - case 128: - return 128; - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { - return 256; - GGML_UNUSED_VARS(cc, ncols); -} - -static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { - return 256; - GGML_UNUSED(ncols); -} - -static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { -#ifdef RDNA - return 3; -#else - return ncols <= 16 ? 3 : 2; -#endif // RDNA - GGML_UNUSED(ncols); -} - -template // D == head size -__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) -static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef GGML_USE_WMMA_FATTN - NO_DEVICE_CODE; - return; -#endif // GGML_USE_WMMA_FATTN - - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - constexpr int warp_size = 32; - constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size; - constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); - static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); - constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); - static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); - constexpr int cpy_ne = cpy_nb / 4; - - constexpr int cpw = ncols/nwarps; // cols per warp - - // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. - // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. -#ifdef FAST_FP16_AVAILABLE - constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; - - __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ half2 Q_tmp[ncols][D/2]; - __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. - half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#else - constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; - - __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ float Q_tmp[ncols][D]; - __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. - float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#endif // FAST_FP16_AVAILABLE - static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); - - float KQ_max[cpw]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - KQ_max[j0/nwarps] = -FLT_MAX/2.0f; - } - float KQ_sum[cpw] = {0.0f}; - - // Load Q data, convert to FP16 if fast. -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - const int j = j0 + threadIdx.y*cpw; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - float tmp_f[cpy_ne_D] = {0.0f}; - if (ic0 + j < ne01) { - ggml_cuda_memcpy_1(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); - } - -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp_f[i1] *= scale; - } - -#ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { - tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); - } - ggml_cuda_memcpy_1(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); -#else - ggml_cuda_memcpy_1 (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f); -#endif // FAST_FP16_AVAILABLE - } - } - - __syncthreads(); - - // Main loop over KV cache: - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float KQ_max_new[cpw]; -#pragma unroll - for (int j = 0; j < cpw; ++j) { - KQ_max_new[j] = KQ_max[j]; - } - - float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. - - // KQ = K @ Q matrix multiplication: -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], - &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); - } -#else - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size; -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { - half2 tmp_h2[cpy_ne_kqnb/2]; - ggml_cuda_memcpy_1( - tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); - - float2 tmp_f2[cpy_ne_kqnb/2]; -#pragma unroll - for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { - tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); - } - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { - half2 K_k[kq_stride/warp_size][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; -#else -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { - float K_k[kq_stride/warp_size][cpy_ne]; - float Q_k[cpw][cpy_ne]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { -#pragma unroll - for (int k = 0; k < cpy_ne; ++k) { - ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); - } - } - } - } - - if (k_KQ_0 + kq_nbatch < D) { - __syncthreads(); // Sync not needed on last iteration. - } - } - - // Apply logit softcap, mask, update KQ_max: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - - if (use_logit_softcap) { - KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - - KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - } - - __syncthreads(); - - // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { -#ifdef FAST_FP16_AVAILABLE - half tmp[kq_stride/warp_size][softmax_iter_j]; -#else - float tmp[kq_stride/warp_size][softmax_iter_j]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_max_new[j0+j1] = warp_reduce_max(KQ_max_new[j0+j1]); - const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); - KQ_max[j0+j1] = KQ_max_new[j0+j1]; - - float KQ_sum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); - KQ_sum_add += val; - tmp[i0/warp_size][j1] = val; - } - KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; - VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - ggml_cuda_memcpy_1( - KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); - } - } - - // VKQ = V @ KQ matrix multiplication: - constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. - static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); -#pragma unroll - for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { - const int k_tile = k1 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], - &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - half2 tmp_h2[cpy_ne_D/2]; - ggml_cuda_memcpy_1( - tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); - - float2 tmp_f2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { - tmp_f2[i1] = __half22float2(tmp_h2[i1]); - } - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - half2 V_k[(D/2)/warp_size]; - half2 KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - half tmp[softmax_iter_j]; - ggml_cuda_memcpy_1( - &tmp, KQ[j][k0 + k1]); -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_k[j0+j1] = __half2half2(tmp[j1]); - } - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; - } - } - } -#else -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - float2 V_k[(D/2)/warp_size]; - float KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - ggml_cuda_memcpy_1( - &KQ_k[j0], KQ[j][k0 + k1]); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; - VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; - } - } - } -#endif // FAST_FP16_AVAILABLE - - __syncthreads(); - } - } - - - // Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - float KQ_max_new_j = fmaxf(KQ_max[j0], sink); - KQ_max_new_j = warp_reduce_max(KQ_max_new_j); - - const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); - KQ_max[j0] = KQ_max_new_j; - - const float val = expf(sink - KQ_max[j0]); - KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; - if (threadIdx.x == 0) { - KQ_sum[j0] += val; - } - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size].x *= KQ_max_scale; - VKQ[j0][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - } - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - KQ_sum[j_VKQ_0] = warp_reduce_sum(KQ_sum[j_VKQ_0]); - } - if (gridDim.y == 1) { -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; - } -#else - const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; - VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - - // Write back results: -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); - } - ggml_cuda_memcpy_1(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); - } -#endif // FAST_FP16_AVAILABLE - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int warp_size = 32; - - constexpr size_t nbytes_shared = 0; - -#ifdef GGML_USE_HIP - if constexpr (D <= 128) { - if (Q->ne[1] > 32) { - constexpr int cols_per_block = 64; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - } -#endif // GGML_USE_HIP - - if (Q->ne[1] > 16) { - constexpr int cols_per_block = 32; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - - constexpr int cols_per_block = 16; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); -} - -template -static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; case 64: { - launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); } break; case 128: { - launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; case 256: { - launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); } break; default: { GGML_ABORT("Unsupported head size"); } break; } } - -void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_switch_head_size(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_switch_head_size(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 10dc22d1..2efc9cc8 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1,3 +1,1216 @@ #include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +// nbatch_fa == number of KQ rows to process per iteration +// nbatch_K == number of K columns to load in parallel for KQ calculation + +// TODO optimize kernel parameters for FP16 NVIDIA (P100) +// TODO optimize kernel parameters for head sizes 40, 80, 96, 112 + +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); + } + if (fast_fp16_available(cc)) { + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +} + +static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef GGML_USE_HIP +#ifdef RDNA + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP +} + +static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +// TODO: deduplicate with mma-f16 +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + float2 tmp_f2[cpy_ne/2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = __half22float2(tmp_h2[l]); + } + ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template +static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + __syncthreads(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template +static __device__ __forceinline__ void flash_attn_tile_iter( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory. + // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs]. +#ifdef FAST_FP16_AVAILABLE + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + __syncthreads(); + } else { + static_assert(cpw == 1, "bad cpw"); + __shared__ float KQ_max_new_shared[nwarps]; + if (threadIdx.x == 0) { + KQ_max_new_shared[threadIdx.y] = KQ_max_new[0]; + } + __syncthreads(); + KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np]; + KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef FAST_FP16_AVAILABLE + half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#else + float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); + if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { + KQ_sum_add += val; + } + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + ggml_cuda_memcpy_1( + KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs, + tmp[i0/(np*warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + half2 V_k[(DVp/2)/warp_size]; + half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + half tmp[KQ_cs]; + ggml_cuda_memcpy_1( + &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + float2 V_k[(DVp/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + ggml_cuda_memcpy_1( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } +} + +template // D == head size +__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + + if ( +#ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 512) || +#endif // GGML_USE_WMMA_FATTN + (use_logit_softcap && !(DV == 128 || DV == 256)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + constexpr int warp_size = 32; + constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on. + + const int sequence = blockIdx.z / (ne02/ncols2); + const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols * DKQ/2]; + __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; + __shared__ half KQ[ncols * nbatch_fa]; + half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#else + __shared__ float Q_tmp[ncols * DKQ]; + __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; + __shared__ float KQ[ncols * nbatch_fa]; + float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ncols1 == 1 || col_Q_0 + j < ne01) { + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1( + &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], + tmp_h2); +#else + ggml_cuda_memcpy_1( + &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D], + tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = blockIdx.y*nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_combine = (half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // FAST_FP16_AVAILABLE + float * KQ_sum_combine = (float *) Q_tmp; + + if (threadIdx.y % np != 0) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]); + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // FAST_FP16_AVAILABLE + + if (threadIdx.x == 0) { + KQ_sum_combine[threadIdx.y] = KQ_sum[0]; + } + + return; + } + + __syncthreads(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + half2 tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + float tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && blockIdx.y == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = fmaxf(KQ_max[jc0], sink); + const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j); + KQ_max[jc0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[jc0]); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + if (gridDim.y == 1) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]); +#pragma unroll + for (int i = 0; i < (DVp/2)/warp_size; ++i) { + VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv; + } +#else + const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0]; +#pragma unroll + for (int i = 0; i < (DVp/2)/warp_size; ++i) { + VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv; + VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= ne01) { + return; + } + + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + } + if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (DV <= 128) { + if (Q->ne[1] > 32/ncols2) { + constexpr int cols_per_block = 64; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + +#ifndef GGML_USE_HIP + if constexpr (DV <= 256) +#endif // GGML_USE_HIP + { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + GGML_ABORT("fatal error"); +} + +template +static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template +void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2(ctx, dst); + } +} void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_cuda_flash_attn_ext_tile_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 1848d088..7235f1b7 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "common.cuh" #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0c8e7b3e..fe970ada 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; #endif// FLASH_ATTN_AVAILABLE + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int gqa_ratio = Q->ne[2] / K->ne[2]; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // The effective batch size for the kernel can be increased by gqa_ratio. + // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, + const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + const int cc = ggml_cuda_info().devices[device].cc; - // TODO: temporary until support is extended - // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 - if (K->ne[1] % FATTN_KQ_STRIDE != 0) { - return BEST_FATTN_KERNEL_NONE; - } - switch (K->ne[0]) { + case 40: case 64: - case 128: - case 256: - if (V->ne[0] != K->ne[0]) { - return BEST_FATTN_KERNEL_NONE; - } - break; case 80: case 96: + case 128: case 112: + case 256: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } - if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) { - return BEST_FATTN_KERNEL_NONE; - } break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 16 != 0) { return BEST_FATTN_KERNEL_NONE; } break; @@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0; - - // If Turing tensor cores available, use them except for some cases with batch size 1: - if (turing_mma_available(cc)) { - best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16; + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // If Turing tensor cores available, use them: + if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { if (Q->ne[1] <= 2) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } } - if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply. + if (!gqa_opt_applies && Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; } } - return best; + return BEST_FATTN_KERNEL_MMA_F16; } - // Use kernels specialized for small batch sizes if possible: - if (Q->ne[1] <= 8 && can_use_vector_kernel) { - return BEST_FATTN_KERNEL_VEC; - } - - // For large batch sizes, use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc)) { + // Use the WMMA kernel if possible: + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } return BEST_FATTN_KERNEL_WMMA_F16; } - // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } return BEST_FATTN_KERNEL_TILE; } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu new file mode 100644 index 00000000..a8b15ad7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu new file mode 100644 index 00000000..1da18105 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu new file mode 100644 index 00000000..bc65c723 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu new file mode 100644 index 00000000..10b330fa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu new file mode 100644 index 00000000..254b7d2e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu new file mode 100644 index 00000000..5caffac0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu new file mode 100644 index 00000000..90abb3b1 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu new file mode 100644 index 00000000..7292c0aa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d410080f..81a986f3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,8 +3,17 @@ from glob import glob import os +HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576] + TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v}); +""" + SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../fattn-vec.cuh" @@ -51,6 +60,11 @@ def get_short_name(long_quant_name): for filename in glob("*.cu"): os.remove(filename) +for head_size_kq in HEAD_SIZES_KQ: + head_size_v = head_size_kq if head_size_kq != 576 else 512 + with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: + f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) + for type_k in TYPES_KV: for type_v in TYPES_KV: with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: @@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]: with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + for head_size_kq in HEAD_SIZES_KQ: + if head_size_kq == 40: + continue if head_size_kq != 576 and ncols2 == 16: continue if head_size_kq == 576 and ncols2 != 16: diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 0e2b1847..934aefdc 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index f8477a2e..d76cb519 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") From 53721d6309da37821e7013980b4b6fcb77738b7e Mon Sep 17 00:00:00 2001 From: sirus20x6 Date: Sun, 12 Oct 2025 00:15:00 -0500 Subject: [PATCH 052/104] ggml: Correct SVE implementation in ggml_vec_dot_f16_unroll (llama/16518) The previous SVE implementation for `ggml_vec_dot_f16_unroll` contained a bug due to a copy-paste error. The wrong variable was used in an FMA instruction, leading to incorrect results. This commit corrects the variable usage and improves the clarity of the code by renaming variables to avoid confusion. Co-authored-by: Aaron --- ggml/src/ggml-cpu/vec.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 2751359c..d3834182 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -144,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG for (int i = 0; i < np; i += ggml_f16_step) { ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst + ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements + ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); @@ -160,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); + ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); From 70eb30f28eadf9dd729248d6911f2b495b2b642e Mon Sep 17 00:00:00 2001 From: sirus20x6 Date: Sun, 12 Oct 2025 00:25:37 -0500 Subject: [PATCH 053/104] ggml : Fix FP16 ELU positive branch (llama/16519) Co-authored-by: Aaron --- ggml/src/ggml-cpu/vec.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index d3834182..65c7dfb6 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -820,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i]))); + const float v = GGML_CPU_FP16_TO_FP32(x[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v)); } } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } From be778c992fdb6e3e37b8f48feb84350f07dd4209 Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Sun, 12 Oct 2025 21:53:35 +0800 Subject: [PATCH 054/104] fix UT fault cases: count-equal, argsort, pad OPs (llama/16521) * fix/refactor OP argsort, pad * fix count-equal op * update SYCL OP list * fix format issue --------- Co-authored-by: Zhang Jianyu --- ggml/src/ggml-sycl/backend.hpp | 2 + ggml/src/ggml-sycl/binbcast.cpp | 9 --- ggml/src/ggml-sycl/binbcast.hpp | 6 -- ggml/src/ggml-sycl/common.hpp | 3 +- ggml/src/ggml-sycl/count-equal.cpp | 79 +++++++++++++++++++++ ggml/src/ggml-sycl/count-equal.hpp | 9 +++ ggml/src/ggml-sycl/element_wise.cpp | 78 --------------------- ggml/src/ggml-sycl/element_wise.hpp | 2 - ggml/src/ggml-sycl/ggml-sycl.cpp | 103 +++++++++++++++++----------- ggml/src/ggml-sycl/pad.cpp | 97 ++++++++++++++++++++++++++ ggml/src/ggml-sycl/pad.hpp | 24 +++++++ 11 files changed, 276 insertions(+), 136 deletions(-) create mode 100644 ggml/src/ggml-sycl/count-equal.cpp create mode 100644 ggml/src/ggml-sycl/count-equal.hpp create mode 100644 ggml/src/ggml-sycl/pad.cpp create mode 100644 ggml/src/ggml-sycl/pad.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 410a67b0..6ff3215d 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -18,6 +18,7 @@ #include "concat.hpp" #include "conv.hpp" #include "convert.hpp" +#include "count-equal.hpp" #include "cpy.hpp" #include "dequantize.hpp" #include "dmmv.hpp" @@ -28,6 +29,7 @@ #include "mmvq.hpp" #include "norm.hpp" #include "outprod.hpp" +#include "pad.hpp" #include "quantize.hpp" #include "quants.hpp" #include "rope.hpp" diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index e0a1de0f..0a3883ae 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } -inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); @@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sub(ctx, dst); } -void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); - ggml_sycl_op_count_equal(ctx, dst); -} - void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_mul(ctx, dst); diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp index 34c4064f..9cce0f05 100644 --- a/ggml/src/ggml-sycl/binbcast.hpp +++ b/ggml/src/ggml-sycl/binbcast.hpp @@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) { return a - b; } -static __dpct_inline__ float op_count_equal(const float a, const float b) { - return (a == b) ? 1.0f : 0.0f; -} - -void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - static __dpct_inline__ float op_mul(const float a, const float b) { return a * b; } diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index d66d7ade..338fa08c 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -195,7 +195,8 @@ struct optimize_feature { struct sycl_device_info { int cc; // compute capability - // int nsm; // number of streaming multiprocessors + int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum + // number of compute units on a SYCL device. // size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) bool vmm; // virtual memory support diff --git a/ggml/src/ggml-sycl/count-equal.cpp b/ggml/src/ggml-sycl/count-equal.cpp new file mode 100644 index 00000000..b0a8b482 --- /dev/null +++ b/ggml/src/ggml-sycl/count-equal.cpp @@ -0,0 +1,79 @@ +#include "count-equal.hpp" + +#include + +template +static void count_equal(const T *__restrict__ x, const T *__restrict__ y, + int64_t *__restrict__ dst, const int64_t dk, + const int64_t k) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk; + const int64_t i1 = sycl::min(i0 + dk, k); + + int nequal = 0; + + for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) { + const T xi = x[i]; + const T yi = y[i]; + nequal += xi == yi; + } + + nequal = warp_reduce_sum(nequal); + + if (item_ct1.get_local_id(2) != 0) { + return; + } + + dpct::atomic_fetch_add( + (int *)dst, nequal); +} + +void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT( dst->type == GGML_TYPE_I64); + + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + int64_t * dst_d = (int64_t *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + const int id = get_current_device_id(); + const int nsm = ggml_sycl_info().devices[id].nsm; + + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int"); + const int64_t dne = + GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE); + + SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst)))); + + const dpct::dim3 block_dims(WARP_SIZE, 1, 1); + const dpct::dim3 block_nums( + std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) / + SYCL_COUNT_EQUAL_CHUNK_SIZE), + 1, 1); + + switch (src0->type) { + case GGML_TYPE_I32: { + const int *src0_d = (const int *)src0->data; + const int *src1_d = (const int *)src1->data; + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + count_equal(src0_d, src1_d, dst_d, dne, ne); + GGML_UNUSED(item_ct1); + }); + + } break; + default: + GGML_ASSERT(false); + break; + } +} diff --git a/ggml/src/ggml-sycl/count-equal.hpp b/ggml/src/ggml-sycl/count-equal.hpp new file mode 100644 index 00000000..f7f4fcbd --- /dev/null +++ b/ggml/src/ggml-sycl/count-equal.hpp @@ -0,0 +1,9 @@ +#ifndef GGML_SYCL_COUNT_EQUAL_HPP +#define GGML_SYCL_COUNT_EQUAL_HPP +#include "common.hpp" + +#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128 + +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif //GGML_SYCL_COUNT_EQUAL_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index c2da2fb4..aeeb3875 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } -template -static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, - const sycl::nd_item<3> &item_ct1) { - int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2); - if (nidx >= ne0) { - return; - } - - // operation - int offset_dst = nidx + item_ct1.get_group(1) * ne0 + - item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) { - int offset_src = nidx + item_ct1.get_group(1) * ne00 + - item_ct1.get_group(0) * ne00 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - dst[offset_dst] = static_cast(0.0f); - } -} - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, const sycl::nd_item<1> &item_ct1) { @@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, }); } -template -static void pad_sycl(const T *x, T *dst, const int ne00, - const int ne01, const int ne02, const int ne0, - const int ne1, const int ne2, queue_ptr stream) { - int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE); - sycl::range<3> gridDim(ne2, ne1, num_blocks); - stream->parallel_for( - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); }); -} - template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) @@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx } } -template -static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], - (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], - (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} } // namespace ggml_sycl_detail @@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te }); } -static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst, - [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, - queue_ptr stream) { - ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream); - }); -} - static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float min_val; float max_val; @@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_upscale(ctx, dst); } -void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_pad(ctx, dst); -} void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 50749e87..43474317 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e4cc3c8e..45b8c216 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -85,9 +85,11 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); + info.devices[i].nsm = prop.get_max_compute_units(); info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); - info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.devices[i].smpbo = prop.get_local_mem_size(); + + info.max_work_group_sizes[i] = prop.get_max_work_group_size(); } for (int id = 0; id < info.device_count; ++id) { @@ -1512,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) { template __dpct_inline__ static void k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad, - const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) { + const int tasks_per_thread, const sycl::nd_item<3> &item_ct1, + uint8_t *dpct_local) { // bitonic sort - int col = item_ct1.get_local_id(2); + int col_index = item_ct1.get_local_id(2); int row = item_ct1.get_group(1); - if (col >= ncols_pad) { - return; + for (int i = 0; i < tasks_per_thread; i++) { + int col = col_index * tasks_per_thread + i; + if (col >= ncols_pad) { + return; + } } const float * x_row = x + row * ncols; auto dst_row = (int *)dpct_local; // initialize indices - dst_row[col] = col; + for (int i=0;i 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] > x_row[dst_row[ixj]] : - x_row[dst_row[col]] < x_row[dst_row[ixj]])) - ) { - ggml_sycl_swap(dst_row[col], dst_row[ixj]); - } - } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] < x_row[dst_row[ixj]] : - x_row[dst_row[col]] > x_row[dst_row[ixj]])) - ) { - ggml_sycl_swap(dst_row[col], dst_row[ixj]); + for (int i = 0; i < tasks_per_thread; i++) { + int col = col_index * tasks_per_thread + i; + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && + (order == GGML_SORT_ORDER_ASC + ? x_row[dst_row[col]] > x_row[dst_row[ixj]] + : x_row[dst_row[col]] < + x_row[dst_row[ixj]]))) { + ggml_sycl_swap(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && + (order == GGML_SORT_ORDER_ASC + ? x_row[dst_row[col]] < x_row[dst_row[ixj]] + : x_row[dst_row[col]] > + x_row[dst_row[ixj]]))) { + ggml_sycl_swap(dst_row[col], dst_row[ixj]); + } } } + item_ct1.barrier(sycl::access::fence_space::local_space); } - /* - DPCT1118:1: SYCL group functions and algorithms must be encountered - in converged control flow. You may need to adjust the code. - */ - item_ct1.barrier(sycl::access::fence_space::local_space); } } // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; + for (int i = 0; i < tasks_per_thread; i++) { + int col = col_index * tasks_per_thread + i; + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } } - static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past, const sycl::nd_item<3> &item_ct1) { const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + @@ -1738,11 +1750,20 @@ static int next_power_of_2(int x) { static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const int nrows, ggml_sort_order order, - queue_ptr stream) { + queue_ptr stream, int device) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); - const sycl::range<3> block_dims(1, 1, ncols_pad); + int nth = 1; + int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; + while (nth < ncols_pad && nth < max_block_size) + nth *= 2; + if (nth > max_block_size) + nth = max_block_size; + + const int tasks_per_thread = ncols_pad / nth; + + const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); @@ -1755,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { k_argsort_f32_i32( - x, dst, ncols, ncols_pad, item_ct1, - dpct_local_acc_ct1.get_multi_ptr() + x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1, + dpct_local_acc_ct1 + .get_multi_ptr() .get()); }); }); @@ -1769,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { k_argsort_f32_i32( - x, dst, ncols, ncols_pad, item_ct1, - dpct_local_acc_ct1.get_multi_ptr() + x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1, + dpct_local_acc_ct1 + .get_multi_ptr() .get()); }); }); @@ -2142,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream); + argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, + main_stream, ctx.device); } inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { @@ -4413,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ACC: return true; case GGML_OP_PAD: - return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + return ggml_is_contiguous(op->src[0]); case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-sycl/pad.cpp b/ggml/src/ggml-sycl/pad.cpp new file mode 100644 index 00000000..413712c5 --- /dev/null +++ b/ggml/src/ggml-sycl/pad.cpp @@ -0,0 +1,97 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +//#include "common.hpp" +#include "pad.hpp" + +static void pad_f32(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + int i0 = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + int i1 = item_ct1.get_group(1); + int i2 = item_ct1.get_group(0) % ne2; + int i3 = item_ct1.get_group(0) / ne2; + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + // operation + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + if ((i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t i00 = i0 - lp0; + const int64_t i01 = i1 - lp1; + const int64_t i02 = i2 - lp2; + const int64_t i03 = i3 - lp3; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne00 = ne0 - lp0 - rp0; + + const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + + i02 * (ne00 * ne01) + i01 * ne00 + i00; + + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = 0.0f; + } +} + +static void pad_f32_sycl(const float *src, float *dst, const int lp0, + const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, + const int rp3, const int ne0, const int ne1, + const int ne2, const int ne3, + dpct::queue_ptr stream) { + int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3); + stream->parallel_for( + sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, + ne2, ne3); + }); +} + +void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; + + pad_f32_sycl(src0_d, dst_d, + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); +} + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_pad(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/pad.hpp b/ggml/src/ggml-sycl/pad.hpp new file mode 100644 index 00000000..b099e9b7 --- /dev/null +++ b/ggml/src/ggml-sycl/pad.hpp @@ -0,0 +1,24 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_PAD_HPP +#define GGML_SYCL_PAD_HPP + +#include "common.hpp" + +#define SYCL_PAD_BLOCK_SIZE 256 + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_PAD_HPP From 7f22fe5d8fe3b821f0f329bd786d3daa0a0f0181 Mon Sep 17 00:00:00 2001 From: Sam/Samuel <57896620+cern1710@users.noreply.github.com> Date: Mon, 13 Oct 2025 02:43:14 +0800 Subject: [PATCH 055/104] metal : add opt_step_adamw and op_sum (llama/16529) * scaffold to support opt step adamw on metal (not written so far) * add opt-step-adamw kernel for metal * pass op->src[4] as a separate buffer to the pipeline * add bounds check to opt-step-adamw kernel * complete scaffold for GGML_OP_SUM * naive GGML_OP_SUM kernel * remove unwanted comment * change OP_SUM capability gate * Add has_simdgroup_reduction to both ops to pass CI --- ggml/src/ggml-metal/ggml-metal-device.cpp | 37 ++++++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 2 + ggml/src/ggml-metal/ggml-metal-device.m | 3 + ggml/src/ggml-metal/ggml-metal-impl.h | 8 +++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 68 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 2 + ggml/src/ggml-metal/ggml-metal.metal | 52 +++++++++++++++++ 7 files changed, 172 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e23abdda..335d5848 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_SUM); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -1482,3 +1501,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_ADAMW); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1034e4bb..283e70fa 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -134,6 +135,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 95279730..e38e7076 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: @@ -798,6 +799,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_OPT_STEP_ADAMW: + return has_simdgroup_reduction; default: return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index c9dff873..c4c9f0a7 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -544,6 +544,10 @@ typedef struct{ float limit; } ggml_metal_kargs_glu; +typedef struct { + uint64_t np; +} ggml_metal_kargs_sum; + typedef struct { int64_t ne00; int64_t ne01; @@ -773,4 +777,8 @@ typedef struct { uint64_t nb01; } ggml_metal_kargs_argmax; +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_adamw; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5f937044..c01c0b18 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_glu(ctx, idx); } break; + case GGML_OP_SUM: + { + n_fuse = ggml_metal_op_sum(ctx, idx); + } break; case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: { @@ -410,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argmax(ctx, idx); } break; + case GGML_OP_OPT_STEP_ADAMW: + { + n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); @@ -840,6 +848,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const uint64_t n = (uint64_t) ggml_nelements(op->src[0]); + + ggml_metal_kargs_sum args = { + /*.np =*/ n, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + + return 1; +} + int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3401,3 +3433,39 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { return 1; } + +int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_adamw args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index d4cb9446..6641cf5d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); @@ -78,6 +79,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); #ifdef __cplusplus } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ddc28504..780d6a97 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32( } } +kernel void kernel_op_sum_f32( + constant ggml_metal_kargs_sum & args, + device const float * src0, + device float * dst, + ushort tiitg[[thread_index_in_threadgroup]]) { + + if (tiitg != 0) { + return; + } + + float acc = 0.0f; + for (ulong i = 0; i < args.np; ++i) { + acc += src0[i]; + } + + dst[0] = acc; +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -8754,3 +8772,37 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_opt_step_adamw_f32( + constant ggml_metal_kargs_opt_step_adamw & args, + device float * x, + device const float * g, + device float * g_m, + device float * g_v, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[gid]; + const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); + const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); + + g_m[gid] = gmi; + g_v[gid] = gvi; + + const float mh = gmi * beta1h; + const float vh = sqrt(gvi * beta2h) + eps; + + x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; +} From 53e21364a6950a7193a211d9b53331e626c3fe78 Mon Sep 17 00:00:00 2001 From: hipudding Date: Mon, 13 Oct 2025 08:52:22 +0800 Subject: [PATCH 056/104] CANN: Update several operators to support FP16 data format (llama/16251) Many Ascend operators internally use FP16 precision for computation. If input data is in FP32, it must first be cast to FP16 before computation, and then cast back to FP32 after computation, which introduces unnecessary cast operations. Moreover, FP16 computation requires significantly less workload compared to FP32, leading to noticeable efficiency improvements. In this change, `get_rows`, `rms_norm`, and `flash_attn_ext` are extended to support multiple data types. Validation on the Qwen2 0.5b model shows correct accuracy and about 10% performance gain in concurrent scenarios. Co-authored-by: noemotiovon <757486878@qq.com> --- ggml/src/ggml-cann/aclnn_ops.cpp | 197 +++++++++++++++---------------- 1 file changed, 96 insertions(+), 101 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 434023dd..240e8a1b 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -894,14 +894,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, } /** - * @brief Get or expand a cached float32 tensor filled with a scalar value. + * @brief Get or expand a cached tensor filled with a scalar value. * - * This function manages cached device memory for float32 tensors. If the current + * This function manages cached device memory for tensors. If the current * cache size is insufficient for the requested tensor shape, the old memory will - * be released and new memory will be allocated. The allocated buffer is then - * initialized either with zeros (when @p value == 0.0f) or with the given scalar - * value using CANN operations. Finally, an aclTensor object is created from the - * cached memory and returned. + * be released and new memory will be allocated. The allocated buffer is + * initialized with the given scalar value using CANN operations. + * Finally, an aclTensor object is created from the cached memory and returned. * * @param ctx The CANN backend context that manages device memory. * @param buffer A pointer to the cached device buffer (will be allocated @@ -910,17 +909,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, * updated when the cache is expanded. * @param ne The tensor shape array (number of elements in each dimension). * @param nb The stride size for each dimension. + * @param dtype Data type of cached tensor. * @param dims The number of tensor dimensions. * @param value The scalar value used to fill the tensor (supports zero * initialization via memset or arbitrary values via fill_scalar). * @return An aclTensor pointer created from the cached buffer. */ -static aclTensor* get_f32_cache_acl_tensor( +static aclTensor* get_cache_acl_tensor( ggml_backend_cann_context& ctx, void** buffer, int64_t &cache_element, int64_t* ne, size_t* nb, + ggml_type dtype, int64_t dims, float value) { // Calculate total number of elements @@ -928,7 +929,7 @@ static aclTensor* get_f32_cache_acl_tensor( for (int i = 0; i < dims; i++) { n_element *= ne[i]; } - size_t size = n_element * sizeof(float); + size_t size = n_element * ggml_type_size(dtype); // Allocate or expand cache if needed if (cache_element < n_element) { @@ -941,19 +942,17 @@ static aclTensor* get_f32_cache_acl_tensor( cache_element = n_element; // Initialize cache - if (value == 0.0f) { - ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream())); - } else { - int64_t pool_ne[1] = { n_element }; - size_t pool_nb[1] = { sizeof(float) }; - aclTensor* acl_value = ggml_cann_create_tensor( - *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1); - aclnn_fill_scalar(ctx, 1, acl_value); - ggml_cann_release_resources(ctx, acl_value); - } + int64_t pool_ne[1] = { n_element }; + size_t pool_nb[1] = { ggml_type_size(dtype) }; + aclTensor* acl_value = ggml_cann_create_tensor( + *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype), + pool_ne, pool_nb, 1); + aclnn_fill_scalar(ctx, value, acl_value); + ggml_cann_release_resources(ctx, acl_value); } - return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims); + return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), + ggml_type_size(dtype), ne, nb, dims); } void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -965,35 +964,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - // build gamma, one... + // build gamma. size_t acl_gamma_nb[GGML_MAX_DIMS]; - acl_gamma_nb[0] = sizeof(float); + // gamma's type is the same with dst. + acl_gamma_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1]; } - aclTensor* acl_gamma = get_f32_cache_acl_tensor( + aclTensor* acl_gamma = get_cache_acl_tensor( ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, src->ne, acl_gamma_nb, + dst->type, 1, // dims 1.0f // value ); - // build rstd, zero... + // build rstd. int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]}; size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; + // rstd will always be F32. acl_rstd_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1]; } - aclTensor* acl_rstd = get_f32_cache_acl_tensor( + aclTensor* acl_rstd = get_cache_acl_tensor( ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size, acl_rstd_ne, acl_rstd_nb, + GGML_TYPE_F32, GGML_MAX_DIMS - 1, 0.0f // value ); @@ -1765,33 +1768,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; // src ggml_tensor* src1 = dst->src[1]; // index + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + switch (src0->type) { - case GGML_TYPE_F32: { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); - break; - } - case GGML_TYPE_F16: { - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float)); - void* src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(float); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + case GGML_TYPE_F16: + case GGML_TYPE_F32: + if(src0->type == dst->type) { + aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, + dst->data, dst->ne, dst->nb, + src1, dst->type); + } else { + aclTensor* acl_src0 = ggml_cann_create_tensor(src0); + ggml_cann_pool_alloc src_buffer_allocator( + ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); + void* src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = dst->nb[0]; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + aclTensor* src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_trans_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); + aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, + dst->data, dst->ne, dst->nb, + src1, dst->type); + ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); } - aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type), - src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); - ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); break; - } case GGML_TYPE_Q8_0: { // add 1 dim for bcast mul. size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], @@ -1799,7 +1804,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] weight_ne[0] = QK8_0; weight_ne[1] = src0->ne[0] / QK8_0; @@ -1809,7 +1813,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { weight_ne[i] = src0->ne[i - 1]; weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,1] scale_ne[0] = 1; scale_ne[1] = src0->ne[0] / QK8_0; @@ -1819,18 +1822,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { scale_ne[i] = src0->ne[i - 1]; scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; - dequant_nb[0] = sizeof(float); + dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); ggml_cann_pool_alloc dequant_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float)); - + ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type)); aclTensor* acl_weight_tensor = ggml_cann_create_tensor( src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, GGML_MAX_DIMS + 1); @@ -1838,16 +1838,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); aclTensor* dequant_tensor = ggml_cann_create_tensor( - dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float), + dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); - dequant_nb[0] = sizeof(float); + dequant_nb[0] = ggml_type_size(dst->type); dequant_ne = src0->ne; for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, dst->nb, @@ -1965,16 +1963,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, // Only check env once. static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { - int64_t acl_stride[2] = {1, transpose_ne[1]}; - - // Reverse ne. - std::reverse(transpose_ne, transpose_ne + n_dims); - - std::vector storageDims = {transpose_ne[0], transpose_ne[1]}; - - acl_weight_tensor = aclCreateTensor( - transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride, - 0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data); + acl_weight_tensor = + ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); @@ -3178,7 +3168,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_src0_f16_tensor = nullptr; aclTensor* acl_src1_f16_tensor = nullptr; aclTensor* acl_src2_f16_tensor = nullptr; - aclTensor* acl_dst_f16_tensor = nullptr; // Step 1: cast the src0 (Query) to fp16 if needed ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); @@ -3216,22 +3205,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); - ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - void* out_f16_buffer = out_f16_allocator.alloc( - ggml_nelements(dst) * faElemSize); - - int64_t* out_f16_ne = src0_bsnd_ne; - size_t out_f16_nb[GGML_MAX_DIMS]; - out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; - } - - acl_dst_f16_tensor = ggml_cann_create_tensor( - out_f16_buffer, faDataType, faElemSize, - out_f16_ne, out_f16_nb, GGML_MAX_DIMS - ); - // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp aclTensor* bcast_pse_tensor = nullptr; @@ -3334,8 +3307,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ int64_t keyAntiquantMode = 0; int64_t valueAntiquantMode = 0; - // Step 5: launch the FusedInferAttentionScoreV2 kernel. - // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + aclTensor * fa_dst_tensor = nullptr; + aclTensor * acl_dst_tensor = nullptr; + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + if (dst->type == GGML_TYPE_F32) { + void* out_f16_buffer = out_f16_allocator.alloc( + ggml_nelements(dst) * faElemSize); + + int64_t* out_f16_ne = src0_bsnd_ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } + + fa_dst_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); + } + else { + fa_dst_tensor = ggml_cann_create_tensor(dst); + } GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v @@ -3357,23 +3351,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ blockSize, antiquantMode, // blockSize, antiquantMode softmaxLseFlag, // softmaxLseFlag keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode - acl_dst_f16_tensor, // attentionOut + fa_dst_tensor, // attentionOut nullptr // softmaxLse ); - // Step 6: post-processing, permute and cast to f32 - aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - // TODO: when dst is fp16, don't need cast - aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ggml_cann_release_resources(ctx, acl_src0_f16_tensor, - acl_src1_f16_tensor, - acl_src2_f16_tensor, - acl_dst_f16_tensor, - acl_dst_tensor); - if(src3 != nullptr){ - ggml_cann_release_resources(ctx, bcast_pse_tensor); + if (dst->type == GGML_TYPE_F32) { + // Step 6: post-processing, permute and cast to f32 + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); } - }else{ + + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, + acl_src1_f16_tensor, + acl_src2_f16_tensor, + fa_dst_tensor, + acl_dst_tensor, + bcast_pse_tensor); + + } else { GGML_ABORT("Function is not implemented."); } } From ccac1b4772c7229405a9aa3b4e4c93bdbbcc7bea Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 Oct 2025 11:22:27 +0300 Subject: [PATCH 057/104] ggml : fix scalar path for computing norm (llama/16558) --- ggml/src/ggml-cpu/vec.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index b8e37052..43dc7537 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa #endif for (; i < n; ++i) { float val = x[i] - mean; + y[i] = val; val *= val; sum += (ggml_float)val; - y[i] = val; } return sum/n; } From bfd88b8b6ee31ce7733e42f6107f0731a83cadc7 Mon Sep 17 00:00:00 2001 From: Sam/Samuel <57896620+cern1710@users.noreply.github.com> Date: Mon, 13 Oct 2025 16:25:02 +0800 Subject: [PATCH 058/104] metal: add support for opt_step_sgd (llama/16539) * metal: add support for opt_step_sgd * add newline to pass EditorConfig check --- ggml/src/ggml-metal/ggml-metal-device.cpp | 19 ++++++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 4 +++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 38 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 14 +++++++++ 7 files changed, 78 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 335d5848..866cd2da 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1519,3 +1519,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_ return res; } + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_SGD); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 283e70fa..28ae2e17 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -136,6 +136,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index e38e7076..fc508304 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -800,6 +800,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: return has_simdgroup_reduction; default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index c4c9f0a7..a448c14f 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -781,4 +781,8 @@ typedef struct { int64_t np; } ggml_metal_kargs_opt_step_adamw; +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_sgd; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c01c0b18..a61ea8fb 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); } break; + case GGML_OP_OPT_STEP_SGD: + { + n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); @@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { return 1; } + +int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_sgd args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 6641cf5d..f3527386 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -80,6 +80,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); #ifdef __cplusplus } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 780d6a97..74a9aa99 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -8806,3 +8806,17 @@ kernel void kernel_opt_step_adamw_f32( x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; } + +kernel void kernel_opt_step_sgd_f32( + constant ggml_metal_kargs_opt_step_sgd & args, + device float * x, + device const float * g, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; +} From 417ecdddc5a1919094965d85bdf4caa604d26288 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Mon, 13 Oct 2025 17:01:24 +0800 Subject: [PATCH 059/104] CANN: fix CPU memory leak in CANN backend (llama/16549) This commit fixes a CPU-side memory leak issue in the CANN backend, which occurred when intermediate aclTensorList objects were not properly released after operator execution. The leak happened during repeated invocations of CANN ops (e.g., FlashAttention), leading to increasing host memory usage over time. Proper resource cleanup (aclDestroyTensorList and related release logic) has been added to ensure that all temporary tensors are correctly freed. --- ggml/src/ggml-cann/aclnn_ops.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 240e8a1b..2857e080 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated( unary_op(ctx, acl_src0, acl_dst); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1); - ggml_cann_release_resources(ctx, acl_src0, acl_dst); - if(src1) - ggml_cann_release_resources(ctx, acl_src1); + ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst); } /** @@ -1851,7 +1849,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { dst->data, dst->ne, dst->nb, src1, dst->type); - ggml_cann_release_resources(ctx, dequant_tensor); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); break; } default: @@ -3290,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_q_tensor = acl_src0_f16_tensor; aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; - auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); - auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); int64_t numHeads = src0->ne[2]; // N int64_t numKeyValueHeads = src1->ne[2]; @@ -3362,8 +3360,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ } ggml_cann_release_resources(ctx, acl_src0_f16_tensor, - acl_src1_f16_tensor, - acl_src2_f16_tensor, + acl_k_tensor_list, + acl_v_tensor_list, fa_dst_tensor, acl_dst_tensor, bcast_pse_tensor); From 8a9c2ba6a1f6b6d97dc86edb78362aa57c328abf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jie=20Fu=20=28=E5=82=85=E6=9D=B0=29?= Date: Mon, 13 Oct 2025 20:48:47 +0800 Subject: [PATCH 060/104] ggml : fix build broken with -march=armv9-a on MacOS (llama/16520) * ggml : fix build broken with -march=armv9-a on MacOS Signed-off-by: Jie Fu * Add #pragma message Signed-off-by: Jie Fu * Address review comment. Signed-off-by: Jie Fu * Update ggml/src/ggml-cpu/ggml-cpu.c --------- Signed-off-by: Jie Fu Co-authored-by: Diego Devesa --- ggml/src/ggml-cpu/ggml-cpu-impl.h | 2 +- ggml/src/ggml-cpu/ggml-cpu.c | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 799e2b11..713bf85e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -68,7 +68,7 @@ struct ggml_compute_params { #endif // __VXE2__ #endif // __s390x__ && __VEC__ -#if defined(__ARM_FEATURE_SVE) +#if defined(__ARM_FEATURE_SVE) && defined(__linux__) #include #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index eded6eb7..ba2a36d9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -689,8 +689,13 @@ bool ggml_is_numa(void) { #endif static void ggml_init_arm_arch_features(void) { -#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__linux__) ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#else + // TODO: add support of SVE for non-linux systems +#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." +#endif #endif } From 77272fe0df319dbf30cac3515c7eba55622dd411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 13 Oct 2025 16:29:45 +0200 Subject: [PATCH 061/104] CUDA: fix numerical issues in tile FA kernel (llama/16540) --- ggml/src/ggml-cuda/fattn-tile.cuh | 44 ++++++++++++------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 2efc9cc8..2b60b3bb 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter( KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); } - KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ? - slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; - KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } } KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); @@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]); - if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) { - KQ_sum_add += val; - } + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; + KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; } KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; @@ -975,26 +976,6 @@ static __global__ void flash_attn_tile( } } - if (gridDim.y == 1) { -#pragma unroll - for (int jc0 = 0; jc0 < cpw; ++jc0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]); -#pragma unroll - for (int i = 0; i < (DVp/2)/warp_size; ++i) { - VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv; - } -#else - const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0]; -#pragma unroll - for (int i = 0; i < (DVp/2)/warp_size; ++i) { - VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv; - VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - // Write back results: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { @@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile( return; } + const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; #ifdef FAST_FP16_AVAILABLE @@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + tmp[i1].x *= scale; + tmp[i1].y *= scale; } if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); @@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale; + } ggml_cuda_memcpy_1( &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); From 66b0fc2fb7059833dc3882d1b07c85d5f3f61a76 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 13 Oct 2025 11:50:37 -0700 Subject: [PATCH 062/104] opencl: fix build targeting CL 2 (llama/16554) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 79d21487..d2759069 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2348,8 +2348,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); if (opencl_c_version.major >= 3) { + // Assume it is not available for 3.0, since it is optional in 3.0. + // If compiling against 3.0, then we can query. + backend_ctx->non_uniform_workgroups = false; +#if CL_TARGET_OPENCL_VERSION >= 300 CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), &backend_ctx->non_uniform_workgroups, 0)); +#endif } else { GGML_ASSERT(opencl_c_version.major == 2); // Non-uniform workgroup sizes is mandatory feature in v2.x. From 25ac94a6cb63093196a5ad366ff0c1e24cb23b7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 Oct 2025 23:07:57 +0300 Subject: [PATCH 063/104] metal : FA support F32 K and V and head size = 32 (llama/16531) * metal : FA support F32 K and V and head size = 32 * graph : remove obsolete comment [no ci] --- ggml/src/ggml-metal/ggml-metal-device.m | 3 +- ggml/src/ggml-metal/ggml-metal.metal | 152 ++++++++++++++++-------- 2 files changed, 105 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fc508304..c3fe8f4e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -693,7 +693,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return true; case GGML_OP_FLASH_ATTN_EXT: // for new head sizes, add checks here - if (op->src[0]->ne[0] != 40 && + if (op->src[0]->ne[0] != 32 && + op->src[0]->ne[0] != 40 && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 80 && op->src[0]->ne[0] != 96 && diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 74a9aa99..1029cf8f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5213,8 +5213,30 @@ kernel void kernel_flash_attn_ext( half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 +#define FA_TYPES_F32 \ + half, half4, simdgroup_half8x8, \ + float, float4x4, simdgroup_float8x8, \ + float, float4x4, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 + typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; +template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5227,6 +5249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5239,6 +5262,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif +template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5250,6 +5274,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5261,6 +5286,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5272,6 +5298,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5283,6 +5310,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -5818,77 +5846,103 @@ kernel void kernel_flash_attn_ext_vec( float, float4, \ float4 +#define FA_TYPES_F32 \ + half4, \ + float4, \ + float4, \ + float, \ + float, float4, \ + float4 + typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES From a12848e8e9ac10286d976253e5c0eb8e382ea36d Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Tue, 14 Oct 2025 09:53:49 +0000 Subject: [PATCH 064/104] cuda : remove legacy copy-op pointer indirection code (llama/16485) * remove legacy copy-op pointer indirection code * further removal of copy-op indirection code * renamed check_node_graph_compatibility_and_refresh_copy_ops function --- ggml/src/ggml-cuda/common.cuh | 7 - ggml/src/ggml-cuda/cpy.cu | 218 ++++++++------------------------ ggml/src/ggml-cuda/cpy.cuh | 6 +- ggml/src/ggml-cuda/ggml-cuda.cu | 33 +---- 4 files changed, 58 insertions(+), 206 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e0abde54..41ff89c4 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -944,13 +944,6 @@ struct ggml_cuda_graph { bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; std::vector ggml_graph_properties; - bool use_cpy_indirection = false; - std::vector cpy_dest_ptrs; - char ** dest_ptrs_d; - int dest_ptrs_size = 0; - // Index to allow each cpy kernel to be aware of it's position within the graph - // relative to other cpy nodes. - int graph_cpynode_index = -1; #endif }; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 746f4396..12d5bf77 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -8,18 +8,16 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); template -static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets const int64_t i03 = i/(ne00 * ne01 * ne02); @@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int } template -static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int nb12, const int nb13) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int cpy_blck(cx + x_offset, cdst + dst_offset); } -// Copy destination pointers to GPU to be available when pointer indirection is in use - -void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers - CUDA_CHECK(cudaStreamSynchronize(stream)); - if (cuda_graph->dest_ptrs_d != nullptr) { - CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); - } - CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); - cuda_graph->dest_ptrs_size = host_dest_ptrs_size; - } - // copy destination pointers to GPU - CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); - cuda_graph->graph_cpynode_index = 0; // reset index -#else - GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream); -#endif -} - template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q8_0_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_0_f32_cuda( @@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_1_f32_cuda( @@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_0_f32_cuda( @@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_1_f32_cuda( @@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + cudaStream_t stream) { const int num_blocks = ne; cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; - char ** dest_ptrs_d = nullptr; - int graph_cpynode_index = -1; -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { - dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; - graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; - } -#else - GGML_UNUSED(disable_indirection_for_this_node); -#endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) @@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - if (src0->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } else { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); - } + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { - ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; - } -#else - GGML_UNUSED(disable_indirection_for_this_node); -#endif - } void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - bool disable_indirection = true; - ggml_cuda_cpy(ctx, src0, dst, disable_indirection); -} - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { - if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - // Prioritize CUDA graph compatibility over direct memory copy optimization. - // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. - if (src0->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else { - return nullptr; - } - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK4_0>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK4_1>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK5_0>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - return (void*) cpy_f32_q; - } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_q_f32, QK5_1>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - return (void*) cpy_flt>; - } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_flt>; - } else { - GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - } + ggml_cuda_cpy(ctx, src0, dst); } diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 0bd3c0c6..a7a87d8f 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -2,10 +2,6 @@ #define CUDA_CPY_BLOCK_SIZE 64 -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); - -void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); - -void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 856e9de2..83b82c1a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, +static bool check_node_graph_compatibility(ggml_cgraph * cgraph, bool use_cuda_graph) { // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; @@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_CPY) { - - // Store the pointers which are updated for each token, such that these can be sent - // to the device and accessed using indirection from CUDA graph - cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); - - // store a pointer to each copy op CUDA kernel to identify it later - void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); - if (!ptr) { - use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); -#endif - } - } - if (!use_cuda_graph) { break; } } - if (use_cuda_graph) { - cuda_ctx->cuda_graph->use_cpy_indirection = true; - // copy pointers to GPU so they can be accessed via indirection within CUDA graph - ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); - } - return use_cuda_graph; } @@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) { return false; } @@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW ) { return false; @@ -3120,7 +3095,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, if (use_cuda_graph) { cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. if (use_cuda_graph && cuda_graph_update_required) { @@ -3147,10 +3122,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - if (!use_cuda_graph) { - cuda_ctx->cuda_graph->use_cpy_indirection = false; - } - #else bool use_cuda_graph = false; bool cuda_graph_update_required = false; From b4c5c6f71fec4c04a1805df73588f52d316f31ce Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 14 Oct 2025 19:15:15 +0800 Subject: [PATCH 065/104] CUDA: add fp kernel for larger batch size MoE (llama/16512) * CUDA: kernel for larger batch sizes for MoE * WIP * WIP * WIP * WIP * WIP * WIP * fixup * tests * Move mmq_ids_helper to mmid * cleanup * Remove redundant checks --- ggml/src/ggml-cuda/mmf.cu | 46 ++++- ggml/src/ggml-cuda/mmf.cuh | 344 ++++++++++++++++++++++++++++++++---- ggml/src/ggml-cuda/mmid.cu | 164 +++++++++++++++++ ggml/src/ggml-cuda/mmid.cuh | 5 + ggml/src/ggml-cuda/mmq.cu | 169 +----------------- 5 files changed, 525 insertions(+), 203 deletions(-) create mode 100644 ggml/src/ggml-cuda/mmid.cu create mode 100644 ggml/src/ggml-cuda/mmid.cuh diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 599e085e..9e2aaf52 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,5 +1,7 @@ #include "ggml.h" #include "mmf.cuh" +#include "mmid.cuh" + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mmf_ids_data ids_info{}; + mmf_ids_data * ids_info_ptr = nullptr; + ggml_cuda_pool_alloc ids_src_compact_dev; + ggml_cuda_pool_alloc ids_dst_compact_dev; + ggml_cuda_pool_alloc expert_bounds_dev; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_dst = ids ? ne1 : ne2; @@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr nchannels_y = ids->ne[0]; } + if (ids && ncols_dst > 16) { + const int64_t n_expert_used = ids->ne[0]; + const int64_t n_experts = ne02; + const int64_t n_tokens = ne12; + const int64_t ne_get_rows = n_tokens * n_expert_used; + + ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows); + ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows); + expert_bounds_dev.alloc(ctx.pool(), n_experts + 1); + + const int si1 = static_cast(ids_s1); + const int sis1 = static_cast(src1->nb[2] / src1->nb[1]); + + GGML_ASSERT(sis1 > 0); + + ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), + static_cast(n_experts), static_cast(n_tokens), static_cast(n_expert_used), static_cast(ne11), si1, sis1, ctx.stream()); + CUDA_CHECK(cudaGetLastError()); + + ids_info.ids_src_compact = ids_src_compact_dev.get(); + ids_info.ids_dst_compact = ids_dst_compact_dev.get(); + ids_info.expert_bounds_dev = expert_bounds_dev.get(); + ids_info.n_experts = static_cast(n_experts); + ids_info.sis1 = sis1; + ids_info_ptr = &ids_info; + } + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; @@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; @@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; @@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } if (mul_mat_id) { - if (type == GGML_TYPE_F32 && src1_ncols > 32) { + if (src0_ne[1] <= 1024 && src1_ncols > 512) { return false; - } - if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) { + } else if(src0_ne[1] > 1024 && src1_ncols > 128) { return false; } } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index a6c3adfc..49d5295b 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,14 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +struct mmf_ids_data { + const int32_t * ids_src_compact = nullptr; + const int32_t * ids_dst_compact = nullptr; + const int32_t * expert_bounds_dev = nullptr; + int n_experts = 0; + int sis1 = 0; +}; + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); @@ -224,6 +232,250 @@ static __global__ void mul_mat_f( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } + +//This kernel is for larger batch sizes of mul_mat_id +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = blockIdx.y; + const int expert_start = expert_bounds[expert_idx]; + const int expert_end = expert_bounds[expert_idx + 1]; + const int ncols_expert = expert_end - expert_start; + + const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block; + const int tile_idx = blockIdx.z; + if (tile_idx >= tiles_for_expert) { + return; + } + + const int col_base = tile_idx * cols_per_block; + + GGML_UNUSED(channel_ratio); + + const int channel_x = expert_idx; + const int sample_dst = 0; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row; + y += int64_t(sample_y) *stride_sample_y; + dst += int64_t(sample_dst)*stride_sample_dst; + + const int32_t * ids_src_expert = ids_src_compact + expert_start; + const int32_t * ids_dst_expert = ids_dst_compact + expert_start; + + extern __shared__ char data_mmv[]; + char * compute_base = data_mmv; + + //const float2 * y2 = (const float2 *) y; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + + if constexpr (std::is_same_v) { + float vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float val = 0.0f; + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + val = y[channel*stride_channel_y + token*stride_col_y + col]; + } + } + vals[j0] = val; + } + }; + + gather_tile(0, vals_buf[0]); + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0]; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { + float2 vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float2 *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; + } + } + vals[j0] = tmp; + } + }; + + if (ntB > 0) { + gather_tile(0, vals_buf[0]); + } + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const float2 tmp = vals_buf[curr_buf][j0]; + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + const int global_j = col_base + j; + if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) { + const int dst_entry = ids_dst_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd); + const int token = (int) qrm.x; + if (token < ncols_dst_total) { + const int slot = (int) qrm.y; + dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, @@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids( const int64_t stride_col_id, const int64_t stride_row_id, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { - if (ids) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream, + const mmf_ids_data * ids_data) { + const bool has_ids_data = ids_data && ids_data->ids_src_compact; + + // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16) + // we prefer the normal mul_mat_f path with has_ids=true. + if (has_ids_data && ncols_dst > 16) { + const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block); + if (max_tiles == 0) { + return; + } + dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles); + + const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); + const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); + + mul_mat_f_ids<<>> + (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, + ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, + sis1_fd, nch_fd); + } else if (ids) { const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; + mul_mat_f<<>> - (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { @@ -258,7 +532,7 @@ void mul_mat_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; @@ -290,7 +564,7 @@ void mul_mat_f_cuda( const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; - const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); const dim3 block_dims(warp_size, nwarps_best, 1); @@ -300,49 +574,57 @@ void mul_mat_f_cuda( mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 2: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 3: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 4: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 5: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 6: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 7: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 8: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; @@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block( case 1: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ - cudaStream_t stream); + cudaStream_t stream, const mmf_ids_data * ids_data); #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ diff --git a/ggml/src/ggml-cuda/mmid.cu b/ggml/src/ggml-cuda/mmid.cu new file mode 100644 index 00000000..3c61e459 --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cu @@ -0,0 +1,164 @@ +#include "common.cuh" +#include "mmid.cuh" + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mm_ids_helper_store { + uint32_t data; + + __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mm_ids_helper[]; + mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mm_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= static_cast(offset)) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mm_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < static_cast(gridDim.x) - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mm_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + switch (n_expert_used) { + case 2: + launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 4: + launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 6: + launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 8: + launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 16: + launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 32: + launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + default: + launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + } +} diff --git a/ggml/src/ggml-cuda/mmid.cuh b/ggml/src/ggml-cuda/mmid.cuh new file mode 100644 index 00000000..ac090aea --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cuh @@ -0,0 +1,5 @@ +#pragma once + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 12bdc629..a2c8760a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,141 +1,6 @@ #include "mmq.cuh" #include "quantize.cuh" - -#include - -// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. -struct mmq_ids_helper_store { - uint32_t data; - - __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { - data = (it & 0x003FFFFF) | (iex_used << 22); - } - - __device__ uint32_t it() const { - return data & 0x003FFFFF; - } - - __device__ uint32_t iex_used() const { - return data >> 22; - } -}; -static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); - -// Helper function for mul_mat_id, converts ids to a more convenient format. -// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. -// ids_dst describes the same mapping but for the dst tensor. -// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. -template -__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) -static __global__ void mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; - const int expert = blockIdx.x; - - extern __shared__ char data_mmq_ids_helper[]; - mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; - - int nex_prev = 0; // Number of columns for experts with a lower index. - int it_compact = 0; // Running index for the compact slice of this expert. - - if constexpr (n_expert_used_template == 0) { - // Generic implementation: - for (int it = 0; it < n_tokens; ++it) { - int iex_used = -1; // The index at which the expert is used, if any. - for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { - const int expert_used = ids[it*si1 + iex]; - nex_prev += expert_used < expert; - if (expert_used == expert) { - iex_used = iex; - } - } - - if (iex_used != -1) { - store[it_compact] = mmq_ids_helper_store(it, iex_used); - } - - if (warp_reduce_any(iex_used != -1)) { - it_compact++; - } - } - } else { - // Implementation optimized for specific numbers of experts used: - static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); - const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. - for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { - const int it = it0 + threadIdx.x / neu_padded; - - const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. - const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? - ids[it*si1 + iex] : INT_MAX; - const int iex_used = expert_used == expert ? iex : -1; - nex_prev += expert_used < expert; - - // Whether the threads at this token position have used the expert: - const int it_compact_add_self = warp_reduce_any(iex_used != -1); - - // Do a scan over threads at lower token positions in warp to get the correct index for writing data: - int it_compact_add_lower = 0; -#pragma unroll - for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { - const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); - if (threadIdx.x >= static_cast(offset)) { - it_compact_add_lower += tmp; - } - } - - if (iex_used != -1) { - store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); - } - - // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: - it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); - } - } - nex_prev = warp_reduce_sum(nex_prev); - - for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { - const mmq_ids_helper_store store_it = store[itc]; - const int it = store_it.it(); - const int iex_used = store_it.iex_used(); - ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; - ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; - } - - if (threadIdx.x != 0) { - return; - } - - expert_bounds[expert] = nex_prev; - - if (expert < static_cast(gridDim.x) - 1) { - return; - } - - expert_bounds[gridDim.x] = nex_prev + it_compact; -} - -template -static void launch_mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { - GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); - GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); - - const int id = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[id].warp_size; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); - - const dim3 num_blocks(n_experts, 1, 1); - const dim3 block_size(warp_size, 1, 1); - const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); - GGML_ASSERT(nbytes_shared <= smpbo); - mmq_ids_helper<<>> - (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); -} +#include "mmid.cuh" static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { @@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q( const int si1 = ids->nb[1] / ggml_element_size(ids); const int sis1 = nb12 / nb11; - switch (n_expert_used) { - case 2: - launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 4: - launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 6: - launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 8: - launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 16: - launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 32: - launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - default: - launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - } + ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); CUDA_CHECK(cudaGetLastError()); } From f2075667fa872b95b8afa3517f938432ffb488ba Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 14 Oct 2025 19:16:21 +0800 Subject: [PATCH 066/104] CUDA: use fastdiv + ggml_cuda_mad for mmvf (llama/16557) * CUDA: use fastdiv + ggml_cuda_mad for mmvf * use bf16 directly + fix formatting * Add exception for HIP code --- ggml/src/ggml-cuda/mmvf.cu | 72 +++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 5b21ef05..57ab8393 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -7,14 +7,14 @@ template static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const int row = blockIdx.x; const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; const int tid = threadIdx.x; @@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x*tmpy.x; - sumf[j] += tmpx.y*tmpy.y; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); } } } else if constexpr (std::is_same_v) { @@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x * tmpy.x; - sumf[j] += tmpx.y * tmpy.y; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); } } } else { @@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f( #endif // FP16_AVAILABLE } } else if constexpr (std::is_same_v) { +//TODO: add support for ggml_cuda_mad for hip_bfloat162 +#if defined(GGML_USE_HIP) const int * x2 = (const int *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[1]) * tmpy.y; + const float tmpx0 = ggml_cuda_cast(reinterpret_cast(&tmpx)[0]); + const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); + ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); } } +#else + const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const nv_bfloat162 tmpx = x2[col2]; +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + } + } +#endif } else { static_assert(std::is_same_v, "unsupported type"); } @@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda( GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda( case 32: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); From 1bdd746bc8989733075c6e321e517b4ef0f6c203 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 14 Oct 2025 14:22:47 +0200 Subject: [PATCH 067/104] CUDA: enable FA for FP32 KV cache (llama/16546) --- ggml/src/ggml-cuda/fattn-vec.cuh | 9 ++------- ggml/src/ggml-cuda/fattn.cu | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 89ab0f16..e1838fdd 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); const int nwarps = nthreads / WARP_SIZE; fattn_kernel_t fattn_kernel = flash_attn_ext_vec; - constexpr bool need_f16_K = false; - constexpr bool need_f16_V = false; + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; constexpr size_t nbytes_shared = 0; launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } @@ -526,11 +526,6 @@ template void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index fe970ada..7dee032c 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg } } -#define FATTN_VEC_CASE(D, type_K, type_V) \ - if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ - return; \ - } \ +#define FATTN_VEC_CASE(D, type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \ + ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ + return; \ + } \ + } \ #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ FATTN_VEC_CASE( 64, type_K, type_V) \ @@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS switch (K->type) { + case GGML_TYPE_F32: case GGML_TYPE_F16: break; case GGML_TYPE_Q4_1: @@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // If Turing tensor cores available, use them: if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { - if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { return BEST_FATTN_KERNEL_VEC; } @@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { - if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { if (Q->ne[1] == 1) { if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_VEC; From 73e200ee851bc74aede125ab3acf9edda6f3e9f7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 14 Oct 2025 07:51:36 -0500 Subject: [PATCH 068/104] vulkan: Improve build time for MSVC (llama/16545) Enable CMP0147 so custom build steps (invoking vulkan-shader-gen) are run in parallel. Enable /MP so source files are compiled in parallel. --- ggml/src/ggml-vulkan/CMakeLists.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 83a83887..de01336c 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,9 +1,18 @@ cmake_minimum_required(VERSION 3.19) cmake_policy(SET CMP0114 NEW) cmake_policy(SET CMP0116 NEW) +if (POLICY CMP0147) + # Parallel build custom build steps + cmake_policy(SET CMP0147 NEW) +endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # Parallel build object files + add_definitions(/MP) +endif() + function(detect_host_compiler) if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) From 393fbbc80b616de3f47afe7a1d3a96f034058452 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 14 Oct 2025 08:53:37 -0500 Subject: [PATCH 069/104] vulkan: Support FA with K/V in F32 (llama/16543) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 +++++++++++++-- .../vulkan-shaders/dequant_funcs_cm2.glsl | 14 +++++++++++++ .../vulkan-shaders/flash_attn_base.glsl | 20 ++++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 ++----- 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3cd89c71..1674dc66 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ } + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) @@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) @@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); - const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); - const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + // For F32, the shader treats it as a block of size 4 (for vec4 loads) + if (k->type == GGML_TYPE_F32) { + k_stride /= 4; + } + if (v->type == GGML_TYPE_F32) { + v_stride /= 4; + } uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); bool aligned = (KV % alignment) == 0 && @@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } switch (op->src[1]->type) { case GGML_TYPE_F16: + case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: // supported in scalar and coopmat2 paths diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 6a5bb457..67baedf7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,6 +1,18 @@ #include "types.glsl" +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { + vec4 block; +}; + +float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const vec4 v = bl.block; + const uint idx = coordInBlock[1]; + const f16vec4 vf16 = f16vec4(v); + return vf16[idx]; +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_F32) +#define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9b1f153b..eb93903c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; -#if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 +#if defined(DATA_A_F32) +layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; +#elif defined(A_TYPE_PACKED16) layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#if defined(DATA_A_F32) +#undef BLOCK_SIZE +#define BLOCK_SIZE 4 +#define BLOCK_BYTE_SIZE 16 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + // iqs is currently always zero in the flash attention shaders + if (binding_idx == BINDING_IDX_K) { + return k_packed.k_data_packed[a_offset + ib]; + } else { + return v_packed.v_data_packed[a_offset + ib]; + } +} +#endif + #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f0cc24ff..184f3f3a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -611,9 +611,6 @@ void process_shaders() { } for (const auto& tname : type_names) { - if (tname == "f32") { - continue; - } if (tname == "bf16") continue; #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -630,7 +627,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); @@ -639,7 +636,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); From 2eb9119754efe2a7a7ece560fd87d0b42044f262 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 14 Oct 2025 22:48:08 +0800 Subject: [PATCH 070/104] CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (llama/16577) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 83b82c1a..da312992 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2876,7 +2876,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } //if rms norm is the B operand, then we don't handle broadcast - if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d2759069..0693d38d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2686,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx // if rms_norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && - !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } From 499f183e751671075127b4b34ae71ea190e3eda7 Mon Sep 17 00:00:00 2001 From: SavicStefan <50296686+SavicStefan@users.noreply.github.com> Date: Tue, 14 Oct 2025 19:18:05 +0200 Subject: [PATCH 071/104] vulkan: Add ACC_TYPE_VEC2 implementation (llama/16203) Signed-off-by: Stefan Savic Co-authored-by: Stefan Savic --- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 85400ac5..a20788c4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -313,12 +313,12 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE sums[WMITER * TM * WNITER * TN]; + ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b[TN]; + FLOAT_TYPE_VEC2 cache_b; - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = ACC_TYPE(0.0f); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); } #endif @@ -360,20 +360,22 @@ void main() { cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; } } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); } } } } + } #endif @@ -388,8 +390,9 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX); + sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif #endif @@ -463,14 +466,21 @@ void main() { const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID - [[unroll]] for (uint cr = 0; cr < TM; cr++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; #ifdef MUL_MAT_ID - if (dr_warp + cr < p.M) { - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #else - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #endif // MUL_MAT_ID } From ff2253b08abb96372f5b45350aaf69ab9fbd514d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Oct 2025 22:08:53 +0300 Subject: [PATCH 072/104] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index b84ddf48..524e2b1c 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -fcc2a5c0cfd81ee0517ee42f1acdc371ec92d598 +c538174d261d8172480f87efcfec8e69aac13ebb From 8ba3c13b0c3e7f35deb324cc5d1d864e4c591cfb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Oct 2025 22:09:02 +0300 Subject: [PATCH 073/104] talk-llama : sync llama.cpp --- examples/talk-llama/llama-graph.cpp | 117 ++++++++++++++++++---------- examples/talk-llama/llama-graph.h | 10 ++- examples/talk-llama/llama-model.cpp | 11 ++- examples/talk-llama/llama.cpp | 1 + 4 files changed, 87 insertions(+), 52 deletions(-) diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index a24853c6..f29a1e98 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); - const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : - (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" : - (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : - (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; + const char * swa_type_str = "unknown"; + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break; + case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break; + case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break; + case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; + }; + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - GGML_ASSERT(kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); + const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { + for (int h = 0; h < 1; ++h) { + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - float * data = (float *) kq_mask->data; + const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; - // [TAG_NO_CACHE_ISWA] - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement"); - - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - - for (int i0 = 0; i0 < n_tokens; ++i0) { - float f = -INFINITY; - - for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { + for (int i0 = 0; i0 < n_tokens; ++i0) { const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; + // mask different sequences if (s0 != s1) { - continue; // skip different sequences + continue; } - if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { - continue; // skip future tokens for causal attention + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; } - // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA] - //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { - // continue; // skip masked tokens for SWA - //} - - // TODO: reimplement this like in llama_kv_cache_unified - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; } + + data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } } + }; + + { + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + + float * data = (float *) self_kq_mask->data; + + std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); + + fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); + + if (debug) { + print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); + } } - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(self_kq_mask_swa); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); + + float * data = (float *) self_kq_mask_swa->data; + + std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); + + fill_mask(data, hparams.n_swa, hparams.swa_type); + + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } } } @@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha( k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_kv = k->ne[1]; - ggml_tensor * cur; - // TODO: replace hardcoded padding with ggml-provided padding - if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) { + if (cparams.flash_attn && kq_b == nullptr) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); - ggml_set_input(inp->kq_mask); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + ggml_set_input(inp->self_kq_mask); - inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } else { + inp->self_kq_mask_swa = nullptr; + inp->self_kq_mask_swa_cnv = nullptr; + } return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); } @@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto & kq_mask = inp->get_kq_mask(); + const bool is_swa = hparams.is_swa(il); + + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index dc84b794..d0c3934f 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -257,10 +257,14 @@ public: void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] + // n_tokens == n_batch + ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 36d495d6..0cdad9ba 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } }; -struct llm_build_gemma_embedding_iswa : public llm_graph_context { - llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_gemma_embedding : public llm_graph_context { + llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA] - auto * inp_attn = build_attn_inp_kv_iswa(); + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: - //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA] + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: @@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GEMMA_EMBEDDING: { - llm = std::make_unique(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER2: { diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index fe5a7a83..38700f97 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits( LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__); return nullptr; } + splits.reserve(n_paths); for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } From 4979e04f5dcaccb36057e059bbaed8a2f5288315 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 15 Oct 2025 10:29:42 +0300 Subject: [PATCH 074/104] release : v1.8.2 --- CMakeLists.txt | 2 +- bindings/javascript/package.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 91b9d0a9..517f30bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.1) +project("whisper.cpp" VERSION 1.8.2) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index ae601157..37bc7509 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.1", + "version": "1.8.2", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 23c19308d8a5786c65effa4570204a881660ff31 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 20 Oct 2025 15:39:48 +0300 Subject: [PATCH 075/104] server : set no_context == true (#3482) --- examples/server/server.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1262c3d6..1d49aa3b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -103,7 +103,7 @@ struct whisper_params { bool use_gpu = true; bool flash_attn = true; bool suppress_nst = false; - bool no_context = false; + bool no_context = true; bool no_language_probabilities = false; std::string language = "en"; @@ -176,7 +176,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); - fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); @@ -240,7 +239,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } - else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; } else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } // server params @@ -572,10 +570,6 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content); } - if (req.has_file("no_context")) - { - params.no_context = parse_str_to_bool(req.get_file_value("no_context").content); - } if (req.has_file("vad")) { params.vad = parse_str_to_bool(req.get_file_value("vad").content); From 8ed913da0e545e9b547a1800c67879aef67ac68f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Oct 2025 20:33:05 +0300 Subject: [PATCH 076/104] metal : avoid using Metal's gpuAddress property (llama/16576) * metal : avoid using Metal's gpuAddress property * metal : fix rope kernels buffer check --- ggml/src/ggml-metal/ggml-metal-device.m | 24 ++++++++++++++---------- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 8 ++++---- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c3fe8f4e..553cf8f5 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -7,6 +7,8 @@ #include +#include + #ifndef TARGET_OS_VISION #define TARGET_OS_VISION 0 #endif @@ -22,6 +24,9 @@ // overload of MTLGPUFamilyMetal3 (not available in some environments) static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; +// virtual address for GPU memory allocations +static atomic_uintptr_t g_addr_device = 0x000000400ULL; + #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -827,7 +832,7 @@ struct ggml_metal_buffer_wrapper { }; struct ggml_metal_buffer { - void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985 + void * all_data; size_t all_size; // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host @@ -965,14 +970,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, if (shared) { res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; - res->owned = true; } else { - // dummy, non-NULL value - we'll populate this after creating the Metal buffer below - res->all_data = (void *) 0x000000400ULL; + // use virtual address from g_addr_device counter + res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; + res->owned = true; + res->device = ggml_metal_device_get_obj(dev); res->queue = ggml_metal_device_get_queue(dev); @@ -983,15 +989,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->buffers[0].metal = nil; if (size_aligned > 0) { - if (props_dev->use_shared_buffers &&shared) { + if (props_dev->use_shared_buffers && shared) { res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; } else { res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; - - res->all_data = (void *) (res->buffers[0].metal.gpuAddress); } } @@ -1139,7 +1143,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) { void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { if (buf->is_shared) { - memset((char *)tensor->data + offset, value, size); + memset((char *) tensor->data + offset, value, size); return; } @@ -1168,7 +1172,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { if (buf->is_shared) { - memcpy((char *)tensor->data + offset, data, size); + memcpy((char *) tensor->data + offset, data, size); return; } @@ -1223,7 +1227,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { if (buf->is_shared) { - memcpy(data, (const char *)tensor->data + offset, size); + memcpy(data, (const char *) tensor->data + offset, size); return; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index a448c14f..fa2d82ce 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -251,6 +251,7 @@ typedef struct { int32_t sect_1; int32_t sect_2; int32_t sect_3; + bool src2; } ggml_metal_kargs_rope; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a61ea8fb..784b7b77 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2969,6 +2969,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { /* sect_1 =*/ sect_1, /* sect_2 =*/ sect_2, /* sect_3 =*/ sect_3, + /* src2 =*/ op->src[2] != nullptr, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 1029cf8f..6d39ddcc 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3748,7 +3748,7 @@ kernel void kernel_rope_norm( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3801,7 +3801,7 @@ kernel void kernel_rope_neox( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3872,7 +3872,7 @@ kernel void kernel_rope_multi( const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); @@ -3939,7 +3939,7 @@ kernel void kernel_rope_vision( const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); // end of mrope - const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); From 0c9d49927c3e90949e4c9db5f44583ad4ba7660a Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Wed, 15 Oct 2025 13:54:15 +0200 Subject: [PATCH 077/104] CUDA: Changing the CUDA scheduling strategy to spin (llama/16585) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA set scheduling strategy to spinning for cc121 * Using prop.major and prop.minor, include HIP and MUSA * Exclude HIP and MUSA * Remove trailing whitespace Co-authored-by: Johannes Gäßler * Remove empty line Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index da312992..a5e77672 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -273,6 +273,15 @@ static ggml_cuda_device_info ggml_cuda_init() { } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") { turing_devices_without_mma.push_back({ id, device_name }); } + + // Temporary performance fix: + // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls. + // TODO: Check for future drivers the default scheduling strategy and + // remove this call again when cudaDeviceScheduleSpin is default. + if (prop.major == 12 && prop.minor == 1) { + CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); + } + #endif // defined(GGML_USE_HIP) } From d8a146b0f9a1af396e1812e3fc6859483752dab1 Mon Sep 17 00:00:00 2001 From: Sam/Samuel <57896620+cern1710@users.noreply.github.com> Date: Wed, 15 Oct 2025 23:05:56 +0900 Subject: [PATCH 078/104] metal: optimise `GGML_OP_SUM` (llama/16559) * optimise GGML_OP_SUM * add non-contiguous tests by permuting the input * change tests to require full contiguity of OP_SUM * cuda : add check GGML_OP_SUM --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 15 ++++++++- ggml/src/ggml-metal/ggml-metal.metal | 42 +++++++++++++++++++++---- 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a5e77672..75fd6db1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3625,9 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: - case GGML_OP_SUM: case GGML_OP_ACC: return true; + case GGML_OP_SUM: + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGSORT: // TODO: Support arbitrary column width return op->src[0]->ne[0] <= 1024; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 553cf8f5..c3c83abe 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -662,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 784b7b77..4f9f6bda 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + int nth = 32; // SIMD width + + while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, (int) n); + + const int nsg = (nth + 31) / 32; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6d39ddcc..496610b1 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, device float * dst, - ushort tiitg[[thread_index_in_threadgroup]]) { + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - if (tiitg != 0) { + if (args.np == 0) { return; } - float acc = 0.0f; - for (ulong i = 0; i < args.np; ++i) { - acc += src0[i]; + const uint nsg = (ntg.x + 31) / 32; + + float sumf = 0; + + for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { + sumf += src0[i0]; } - dst[0] = acc; + sumf = simd_sum(sumf); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float total = 0; + + if (sgitg == 0) { + float v = 0; + + if (tpitg.x < nsg) { + v = shmem_f32[tpitg.x]; + } + + total = simd_sum(v); + + if (tpitg.x == 0) { + dst[0] = total; + } + } } template From 16dab3d122232fc09d2c05a9ed7732f429164c6a Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 15 Oct 2025 10:48:28 -0700 Subject: [PATCH 079/104] opencl: fix FA for f32 (llama/16584) --- ggml/src/ggml-opencl/kernels/flash_attn_f32.cl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index 9c0bab13..a6d74790 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -4,6 +4,7 @@ #define ACC_TYPE4 float4 #define DATA_TYPE float #define DATA_TYPE4 float4 +#define MASK_DATA_TYPE half #define CONVERT_ACC4(x) (x) #define CONVERT_DATA4(x) (x) @@ -148,7 +149,7 @@ __kernel void flash_attn_f32( if (k_row1 >= n_kv) score1 = -INFINITY; if (mask_base != NULL) { - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; } @@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1( } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); score += slope * (ACC_TYPE)mask_ptr[k_idx]; } if (logit_softcap > 0.0f) { @@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1( } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { - const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); score += slope * (ACC_TYPE)mask_ptr[k_idx]; } if (logit_softcap > 0.0f) { From bef9f74553e4dde2e1ac19f116b41a11bc4ce283 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 15 Oct 2025 10:51:04 -0700 Subject: [PATCH 080/104] opencl: add q8_0 mm support (llama/16469) * opencl: add mm_q8_0_f32 * opencl: fix data loading for incomplete tile * opencl: use q8_0 mm for larger matrix * opencl: add some tests to cover the path --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 56 +++++++ .../kernels/mul_mm_f16_f32_l4_lm.cl | 32 +++- .../kernels/mul_mm_f32_f32_l4_lm.cl | 34 ++-- .../kernels/mul_mm_q8_0_f32_l4_lm.cl | 154 ++++++++++++++++++ 5 files changed, 258 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 7e6c8438..6f6bba55 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q8_0_f32_l4_lm mul norm relu diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0693d38d..2ec896fd 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -408,6 +408,7 @@ struct ggml_backend_opencl_context { cl_program program_mul_mv_id_mxfp4_f32_flat; cl_program program_mul_mm_f32_f32_l4_lm; cl_program program_mul_mm_f16_f32_l4_lm; + cl_program program_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16; cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16; @@ -480,6 +481,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; std::vector profiling_info; @@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q8_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q8_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl"); +#endif + backend_ctx->program_mul_mm_q8_0_f32_l4_lm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q8_0: { + if (ne11 < 32) { + break; + } + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl index 9599a0e1..1a1bfe14 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl @@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm( for (int block = 0; block < ne00; block += BK) { for (int l = 0; l < BM; l += loadstride_a) { + if (loadc_a + l < ne01) { const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; - buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; - buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; - buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; - buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h; + } } for (int l = 0; l < BN; l += loadstride_b) { - const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; - buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; - buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; - buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; - buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + if (loadc_b + l < ne11) { + const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h; + } } barrier(CLK_LOCAL_MEM_FENCE); diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl index 58c5178e..39a5d486 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl @@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm( for (int block = 0; block < ne00; block += BK) { for (int l = 0; l < BM; l += loadstride_a) { - const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; - buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; - buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; - buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; - buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + if (loadc_a + l < ne01) { + const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } } for (int l = 0; l < BN; l += loadstride_b) { - const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; - buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; - buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; - buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; - buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + if (loadc_b + l < ne11) { + const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } } barrier(CLK_LOCAL_MEM_FENCE); diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl new file mode 100644 index 00000000..fd47e8a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl @@ -0,0 +1,154 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q8_0_f32_l4_lm( + global char4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 8; + int iqs = idx % 8; + + float d = (float)src0_d[ib]; + global char4 * qs = src0_q + ib*8 + iqs; + char4 q = *qs; + float4 v = convert_float4(q)*d; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} From 757d51d21dc82c108477079d536999b74fcf74d5 Mon Sep 17 00:00:00 2001 From: safranowith Date: Wed, 15 Oct 2025 22:24:51 +0300 Subject: [PATCH 081/104] cpu : add FLOOR, CEIL, ROUND and TRUNC unary operators (llama/16083) * CPU: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators - Added the operators to unary op enum - Implemented API functions - Implemented forward and unary-op logic in CPU backend - Updated ggml_get_n_tasks - Updated operators names array and static_assert - Updated docs and enabled automatic tests * docs: add documentation for ggml_trunc and ggml_trunc_inplace in ggml.h * chore: remove trailing whitespace from ggml.h * Remove unresolved merge markers * Apply review suggestions: cleanup formatting, enum order and leftover artifacts * Regenerate ops.md using create_ops_docs.py --- ggml/include/ggml.h | 44 +++++++++++++++++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 4 +++ ggml/src/ggml-cpu/ops.cpp | 16 +++++++++ ggml/src/ggml-cpu/unary-ops.cpp | 32 +++++++++++++++++ ggml/src/ggml-cpu/unary-ops.h | 4 +++ ggml/src/ggml.c | 62 ++++++++++++++++++++++++++++++++- 6 files changed, 161 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 60c6b63d..d948b00c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -577,6 +577,10 @@ extern "C" { GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_XIELU, + GGML_UNARY_OP_FLOOR, + GGML_UNARY_OP_CEIL, + GGML_UNARY_OP_ROUND, + GGML_UNARY_OP_TRUNC, GGML_UNARY_OP_COUNT, }; @@ -1151,6 +1155,46 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + /** + * Truncates the fractional part of each element in the tensor (towards zero). + * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0 + * Similar to std::trunc in C/C++. + */ + + GGML_API struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_trunc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + + // xIELU activation function // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index ba2a36d9..29c87060 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2184,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 1c43865f..b52f0f84 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_FLOOR: + { + ggml_compute_forward_floor(params, dst); + } break; + case GGML_UNARY_OP_CEIL: + { + ggml_compute_forward_ceil(params, dst); + } break; + case GGML_UNARY_OP_ROUND: + { + ggml_compute_forward_round(params, dst); + } break; + case GGML_UNARY_OP_TRUNC: + { + ggml_compute_forward_trunc(params, dst); + } break; case GGML_UNARY_OP_XIELU: { ggml_compute_forward_xielu(params, dst); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index cf1a4615..a047537b 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -73,6 +73,22 @@ static inline float op_log(float x) { return logf(x); } +static inline float op_floor(float x) { + return floorf(x); +} + +static inline float op_ceil(float x) { + return ceilf(x); +} + +static inline float op_round(float x) { + return roundf(x); +} + +static inline float op_trunc(float x) { + return truncf(x); +} + template static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) { constexpr auto src0_to_f32 = type_conversion_table::to_f32; @@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * unary_op(params, dst); } +void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + +void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { const float alpha_n = ggml_get_op_params_f32(dst, 1); const float alpha_p = ggml_get_op_params_f32(dst, 2); diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index 697c1e0d..fa45d9f0 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bce1375..86f1c31a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1144,9 +1144,13 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "EXP", "GELU_ERF", "XIELU", + "FLOOR", + "CEIL", + "ROUND", + "TRUNC", }; -static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); +static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2749,6 +2753,62 @@ static struct ggml_tensor * ggml_glu_impl( return result; } +// ggml_floor + +struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR); +} + +struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR); +} + +// ggml_ceil + +struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL); +} + +struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL); +} + +//ggml_round + +struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND); +} + +struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND); +} + +//ggml_trunc + +struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC); +} + +struct ggml_tensor * ggml_trunc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC); +} + struct ggml_tensor * ggml_glu( struct ggml_context * ctx, struct ggml_tensor * a, From f7b5ecf195f3d6c1bc033356e6bdf41458b3981d Mon Sep 17 00:00:00 2001 From: yael-works <106673277+yael-works@users.noreply.github.com> Date: Thu, 16 Oct 2025 07:21:28 +0300 Subject: [PATCH 082/104] SYCL: Add GGML_OP_MEAN operator support (llama/16009) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: Add GGML_OP_MEAN operator support * SYCL: Fix formatting for GGML_OP_MEAN case * Update ggml/src/ggml-sycl/ggml-sycl.cpp Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-sycl/ggml-sycl.cpp | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 45b8c216..f3407a81 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); } +inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + const int64_t ncols = dst->src[0]->ne[0]; + const int64_t nrows = ggml_nrows(dst->src[0]); + + sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); + + main_stream->parallel_for( + sycl::range<1>(nrows), + [=](sycl::id<1> row) { + dst_dd[row] /= ncols; + } + ); +} + + inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_I32); @@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds ggml_sycl_op_sum_rows(ctx, dst); } +static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_mean(ctx, dst); +} + static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); @@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SUM_ROWS: ggml_sycl_sum_rows(ctx, dst); break; + case GGML_OP_MEAN: + ggml_sycl_mean(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_sycl_argsort(ctx, dst); break; @@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGSORT: return ggml_is_contiguous(op->src[0]); case GGML_OP_POOL_2D: From 3c136d699a21b10e290d5a302a18e18dfad2c3cb Mon Sep 17 00:00:00 2001 From: takuya kodama Date: Thu, 16 Oct 2025 13:10:32 +0800 Subject: [PATCH 083/104] ggml-cpu: replace putenv with setenv for const-correctness (llama/16573) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why it failed When compiling with strict compiler flags (-Wwrite-strings -Werror=discarded-qualifiers), the build fails with the following error: ``` cmake \ -S . \ -B ../llama.cpp.build \ --preset=x64-linux-gcc-debug \ -DCMAKE_INSTALL_PREFIX=/tmp/local \ -DCMAKE_C_FLAGS="-Wwrite-strings -Werror=discarded-qualifiers" && \ cmake --build ../llama.cpp.build/ ... /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c: In function ‘ggml_cpu_init’: /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c:3572:24: error: passing argument 1 of ‘putenv’ discards ‘const’ qualifier from pointer target type [-Werror=discarded-qualifiers] 3572 | putenv("KMP_BLOCKTIME=200"); // 200ms | ^~~~~~~~~~~~~~~~~~~ In file included from /home/otegami/work/cpp/llama.cpp/ggml/src/./ggml-impl.h:10, from /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h:6, from /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h:3, from /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c:6: /usr/include/stdlib.h:786:26: note: expected ‘char *’ but argument is of type ‘const char *’ 786 | extern int putenv (char *__string) __THROW __nonnull ((1)); | ~~~~~~^~~~~~~~ cc1: some warnings being treated as errors ninja: build stopped: subcommand failed. ``` The issue is that putenv() expects a non-const char * but receives a string literal (const char *). ## How to fix This PR replaces putenv("KMP_BLOCKTIME=200") with setenv("KMP_BLOCKTIME", "200", 0). Benefits of setenv(): - Accepts const char * parameters (no qualifier warnings) - Makes copies of the strings (safer memory handling) - The third parameter (0) ensures we don't overwrite if already set --- ggml/src/ggml-cpu/ggml-cpu.c | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 29c87060..9ec485cf 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3567,13 +3567,17 @@ void ggml_cpu_init(void) { #ifdef GGML_USE_OPENMP //if (!getenv("OMP_WAIT_POLICY")) { // // set the wait policy to active, so that OpenMP threads don't sleep - // putenv("OMP_WAIT_POLICY=active"); + // setenv("OMP_WAIT_POLICY", "active", 0) //} if (!getenv("KMP_BLOCKTIME")) { // set the time to wait before sleeping a thread // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases - putenv("KMP_BLOCKTIME=200"); // 200ms +#ifdef _WIN32 + _putenv_s("KMP_BLOCKTIME", "200"); // 200ms +#else + setenv("KMP_BLOCKTIME", "200", 0); // 200ms +#endif } #endif } From fe965613c030f007dd8f8682b11ed19585c483e6 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 16 Oct 2025 16:41:11 +0800 Subject: [PATCH 084/104] CANN: format code using .clang-format (llama/15863) This commit applies .clang-format rules to all source files under the ggml-cann directory to ensure consistent coding style and readability. The .clang-format option `SortIncludes: false` has been set to disable automatic reordering of include directives. No functional changes are introduced. Co-authored-by: hipudding --- ggml/src/ggml-cann/acl_tensor.cpp | 89 +- ggml/src/ggml-cann/acl_tensor.h | 97 +- ggml/src/ggml-cann/aclnn_ops.cpp | 2508 ++++++++++++++--------------- ggml/src/ggml-cann/aclnn_ops.h | 401 +++-- ggml/src/ggml-cann/common.h | 191 ++- ggml/src/ggml-cann/ggml-cann.cpp | 1109 ++++++------- 6 files changed, 2063 insertions(+), 2332 deletions(-) mode change 100755 => 100644 ggml/src/ggml-cann/acl_tensor.cpp mode change 100755 => 100644 ggml/src/ggml-cann/acl_tensor.h mode change 100755 => 100644 ggml/src/ggml-cann/aclnn_ops.cpp mode change 100755 => 100644 ggml/src/ggml-cann/aclnn_ops.h mode change 100755 => 100644 ggml/src/ggml-cann/common.h mode change 100755 => 100644 ggml/src/ggml-cann/ggml-cann.cpp diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp old mode 100755 new mode 100644 index 8ffac31d..8958ebcd --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -51,28 +51,31 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_DT_UNDEFINED; } -aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, - size_t* nb, int64_t dims, aclFormat format, - size_t offset) { +aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor, + int64_t * ne, + size_t * nb, + int64_t dims, + aclFormat format, + size_t offset) { // If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be // added. int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2]; if (ne == nullptr) { for (int i = 0; i < GGML_MAX_DIMS; i++) { - acl_ne[i] = tensor->ne[i]; + acl_ne[i] = tensor->ne[i]; // The step size of acl is in elements. acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor); } } else { // With bcast for (int i = 0; i < dims; i++) { - acl_ne[i] = ne[i]; + acl_ne[i] = ne[i]; acl_stride[i] = nb[i] / ggml_element_size(tensor); } } - int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims); + int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims); int64_t acl_storage_len = 1; for (int i = 0; i < final_dims; i++) { acl_storage_len += (acl_ne[i] - 1) * acl_stride[i]; @@ -84,15 +87,13 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, std::reverse(acl_ne, acl_ne + final_dims); std::reverse(acl_stride, acl_stride + final_dims); - aclTensor* acl_tensor = aclCreateTensor( - acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, - elem_offset, format, &acl_storage_len, 1, - tensor->data); + aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, + elem_offset, format, &acl_storage_len, 1, tensor->data); return acl_tensor; } -bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) { +bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) { for (int i = 0; i < GGML_MAX_DIMS; i++) { if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) { return true; @@ -101,15 +102,16 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) { return false; } -int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, - const ggml_tensor* src1, - int64_t* bcast_src0_ne, - int64_t* bcast_src1_ne, size_t* bcast_src0_nb, - size_t* bcast_src1_nb) { +int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0, + const ggml_tensor * src1, + int64_t * bcast_src0_ne, + int64_t * bcast_src1_ne, + size_t * bcast_src0_nb, + size_t * bcast_src1_nb) { GGML_ASSERT(ggml_can_repeat(src1, src0)); int bcast_dim_cnt = 0; for (int i = 0; i < GGML_MAX_DIMS; i++) { - int64_t nr = src0->ne[i] / src1->ne[i]; + int64_t nr = src0->ne[i] / src1->ne[i]; bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr; bcast_src1_ne[bcast_dim_cnt] = src1->ne[i]; bcast_src0_nb[bcast_dim_cnt] = src0->nb[i]; @@ -119,21 +121,26 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, // Need to add an extra dim. bcast_src0_ne[bcast_dim_cnt] = nr; bcast_src1_ne[bcast_dim_cnt] = 1; - bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * - bcast_src0_ne[bcast_dim_cnt - 1]; - bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * - bcast_src1_ne[bcast_dim_cnt - 1]; + bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * bcast_src0_ne[bcast_dim_cnt - 1]; + bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * bcast_src1_ne[bcast_dim_cnt - 1]; bcast_dim_cnt++; } } return bcast_dim_cnt; } -int64_t ggml_cann_get_mulmat_bcast_shape( - const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne, - const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb, - int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne, - size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) { +int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne, + const int64_t * weight_ne, + const int64_t * dst_ne, + const size_t * input_nb, + const size_t * weight_nb, + const size_t * dst_nb, + int64_t * bcast_input_ne, + int64_t * bcast_weight_ne, + int64_t * bcast_dst_ne, + size_t * bcast_input_nb, + size_t * bcast_weight_nb, + size_t * bcast_dst_nb) { // input and dst shoule in same shape, except first two dims. GGML_ASSERT(input_ne[2] == dst_ne[2]); GGML_ASSERT(input_ne[3] == dst_ne[3]); @@ -148,34 +155,30 @@ int64_t ggml_cann_get_mulmat_bcast_shape( // Do not use bcast in the first two dimensions because we only support // the bcast batch dimension. Just copy them. if (i < 2 || nr == 1) { - bcast_input_ne[bcast_dim_cnt] = input_ne[i]; + bcast_input_ne[bcast_dim_cnt] = input_ne[i]; bcast_weight_ne[bcast_dim_cnt] = weight_ne[i]; - bcast_dst_ne[bcast_dim_cnt] = dst_ne[i]; + bcast_dst_ne[bcast_dim_cnt] = dst_ne[i]; - bcast_input_nb[bcast_dim_cnt] = input_nb[i]; + bcast_input_nb[bcast_dim_cnt] = input_nb[i]; bcast_weight_nb[bcast_dim_cnt] = weight_nb[i]; - bcast_dst_nb[bcast_dim_cnt] = dst_nb[i]; + bcast_dst_nb[bcast_dim_cnt] = dst_nb[i]; bcast_dim_cnt++; } else { // Need to add an extra dim. - bcast_input_ne[bcast_dim_cnt] = nr; - bcast_dst_ne[bcast_dim_cnt] = nr; + bcast_input_ne[bcast_dim_cnt] = nr; + bcast_dst_ne[bcast_dim_cnt] = nr; bcast_weight_ne[bcast_dim_cnt] = 1; - bcast_input_nb[bcast_dim_cnt] = input_nb[i]; - bcast_dst_nb[bcast_dim_cnt] = dst_nb[i]; + bcast_input_nb[bcast_dim_cnt] = input_nb[i]; + bcast_dst_nb[bcast_dim_cnt] = dst_nb[i]; bcast_weight_nb[bcast_dim_cnt] = weight_nb[i]; bcast_dim_cnt++; - bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr; - bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr; + bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr; + bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr; bcast_weight_ne[bcast_dim_cnt] = weight_ne[i]; - bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] * - bcast_input_ne[bcast_dim_cnt - 1]; - bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] * - bcast_dst_ne[bcast_dim_cnt - 1]; - bcast_weight_nb[bcast_dim_cnt] = - bcast_weight_nb[bcast_dim_cnt - 1] * - bcast_weight_ne[bcast_dim_cnt - 1]; + bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] * bcast_input_ne[bcast_dim_cnt - 1]; + bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] * bcast_dst_ne[bcast_dim_cnt - 1]; + bcast_weight_nb[bcast_dim_cnt] = bcast_weight_nb[bcast_dim_cnt - 1] * bcast_weight_ne[bcast_dim_cnt - 1]; bcast_dim_cnt++; } } diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h old mode 100755 new mode 100644 index 93f09937..cb17ebcc --- a/ggml/src/ggml-cann/acl_tensor.h +++ b/ggml/src/ggml-cann/acl_tensor.h @@ -62,10 +62,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type); * @param offset Offset in bytes for the ACL tensor data. Defaults to 0. * @return Pointer to the created ACL tensor. */ -aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr, - size_t* nb = nullptr, int64_t dims = 0, - aclFormat format = ACL_FORMAT_ND, - size_t offset = 0); +aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor, + int64_t * ne = nullptr, + size_t * nb = nullptr, + int64_t dims = 0, + aclFormat format = ACL_FORMAT_ND, + size_t offset = 0); /** * @brief Template for creating an ACL tensor from provided parameters. typename TYPE @@ -87,12 +89,15 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null * @param offset Offset in bytes for the ACL tensor data. Defaults to 0. * @return Pointer to the created ACL tensor. */ -template -aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, - TYPE type_size, int64_t* ne, TYPE* nb, - int64_t dims, - aclFormat format = ACL_FORMAT_ND, - size_t offset = 0) { +template +aclTensor * ggml_cann_create_tensor(void * data_ptr, + aclDataType dtype, + TYPE type_size, + int64_t * ne, + TYPE * nb, + int64_t dims, + aclFormat format = ACL_FORMAT_ND, + size_t offset = 0) { int64_t tmp_ne[GGML_MAX_DIMS * 2]; int64_t tmp_stride[GGML_MAX_DIMS * 2]; @@ -109,9 +114,8 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, std::reverse(tmp_ne, tmp_ne + dims); std::reverse(tmp_stride, tmp_stride + dims); - aclTensor* acl_tensor = - aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, - format, &acl_storage_len, 1, data_ptr); + aclTensor * acl_tensor = + aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr); return acl_tensor; } @@ -132,7 +136,7 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, * to 1. If such a dimension is found, broadcasting is required to align t1 * with t0 for element-wise operations. */ -bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1); +bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1); /** * @brief Computes broadcast shapes and strides for two ggml_tensors. @@ -187,19 +191,21 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1); * dim1 in a inserted dim, should add nb for dim1, * and all other nb moves to next in order. */ -int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1, - int64_t* bcast_ne_src0, int64_t* bcast_ne_src1, - size_t* bcast_nb_src0, size_t* bcast_nb_src1); +int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0, + const ggml_tensor * src1, + int64_t * bcast_ne_src0, + int64_t * bcast_ne_src1, + size_t * bcast_nb_src0, + size_t * bcast_nb_src1); // Bcast macro to avoid duplicate code. -#define BCAST_SHAPE(src0, src1) \ - int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \ - int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \ - size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \ - size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \ - int64_t bcast_dims = ggml_cann_get_bcast_shape( \ - src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \ - bcast_##src1##_nb); +#define BCAST_SHAPE(src0, src1) \ + int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \ + int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \ + size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \ + size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \ + int64_t bcast_dims = ggml_cann_get_bcast_shape(src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, \ + bcast_##src0##_nb, bcast_##src1##_nb); #define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims @@ -233,26 +239,31 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* sr * before cast dim. * @sa ggml_cann_get_bcast_shape */ -int64_t ggml_cann_get_mulmat_bcast_shape( - const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne, - const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb, - int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne, - size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb); +int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne, + const int64_t * weight_ne, + const int64_t * dst_ne, + const size_t * input_nb, + const size_t * weight_nb, + const size_t * dst_nb, + int64_t * bcast_input_ne, + int64_t * bcast_weight_ne, + int64_t * bcast_dst_ne, + size_t * bcast_input_nb, + size_t * bcast_weight_nb, + size_t * bcast_dst_nb); // Bcast macro to avoid duplicate code. -#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \ - int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \ - int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \ - int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \ - size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \ - size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \ - size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \ - int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \ - input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \ - bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne, \ - bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb); +#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \ + int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \ + int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \ + int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \ + size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \ + size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \ + size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \ + int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \ + input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, bcast_##input##_ne, bcast_##weight##_ne, \ + bcast_##dst##_ne, bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb); -#define BCAST_MUL_MAT_PARAM(tensor) \ - bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims +#define BCAST_MUL_MAT_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims #endif // CANN_ACL_TENSOR_H diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp old mode 100755 new mode 100644 index 2857e080..f030ea01 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -86,9 +86,12 @@ #include "../ggml-common.h" - -void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, - aclTensor ** acl_src1, aclTensor ** acl_dst) { +void bcast_shape(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + aclTensor ** acl_src0, + aclTensor ** acl_src1, + aclTensor ** acl_dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); // Need bcast if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { @@ -103,40 +106,40 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT } } -void ggml_cann_op_unary( - std::function unary_op, - ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_op_unary(std::function unary_op, + ggml_backend_cann_context & ctx, + ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); unary_op(ctx, acl_src, acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst); } -void ggml_cann_op_unary_gated( - std::function unary_op, - ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; +void ggml_cann_op_unary_gated(std::function unary_op, + ggml_backend_cann_context & ctx, + ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); const int32_t swapped = ggml_get_op_params_i32(dst, 1); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); - aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr; - if(src1) { + aclTensor * acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src0 = nullptr, *acl_src1 = nullptr; + if (src1) { GGML_ASSERT(ggml_is_contiguous_1(src1)); GGML_ASSERT(src0->type == src1->type); acl_src0 = ggml_cann_create_tensor(src0); acl_src1 = ggml_cann_create_tensor(src1); } else { - int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]}; - size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]}; - acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0); + int64_t ne[] = { src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3] }; + size_t nb[] = { src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3] }; + acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0); acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0)); if (swapped) { std::swap(acl_src0, acl_src1); @@ -159,10 +162,12 @@ void ggml_cann_op_unary_gated( * @param repeat_array The array specifying the number of repetitions along each * dimension. */ -static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, int64_t* repeat_array) { +static void aclnn_repeat(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_dst, + int64_t * repeat_array) { // repeat tensor along each dim with repeat_array - aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS); + aclIntArray * repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS); GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats, acl_dst); ggml_cann_release_resources(ctx, repeats); @@ -181,61 +186,63 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param cast_data_type The target data type to which the source tensor will be * casted. */ -static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, aclDataType cast_data_type) { +static void aclnn_cast(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_dst, + aclDataType cast_data_type) { GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst); } -void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; GGML_ASSERT(ggml_can_repeat(src, dst)); - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2], - dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]}; + int64_t repeatsArray[] = { dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2], dst->ne[1] / src->ne[1], + dst->ne[0] / src->ne[0] }; aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray); ggml_cann_release_resources(ctx, acl_src, acl_dst); } -void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, - aclTensor* acl_src1, aclTensor* acl_dst) { - float alphaValue = 1.0f; - aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - if (acl_dst != nullptr) +void aclnn_add(ggml_backend_cann_context & ctx, aclTensor * acl_src0, aclTensor * acl_src1, aclTensor * acl_dst) { + float alphaValue = 1.0f; + aclScalar * alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + if (acl_dst != nullptr) { GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst); - else + } else { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha); + } ggml_cann_release_resources(ctx, alpha); } -void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, - aclTensor* acl_src1, aclTensor* acl_dst) { - float alphaValue = 1.0f; - aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - if (acl_dst != nullptr) +void aclnn_sub(ggml_backend_cann_context & ctx, aclTensor * acl_src0, aclTensor * acl_src1, aclTensor * acl_dst) { + float alphaValue = 1.0f; + aclScalar * alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + if (acl_dst != nullptr) { GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha, acl_dst); - else + } else { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha); + } ggml_cann_release_resources(ctx, alpha); } -void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst) { - if (acl_dst != nullptr) +void aclnn_mul(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_other, aclTensor * acl_dst) { + if (acl_dst != nullptr) { GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst); - else + } else { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other); + } } -void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst) { - if (acl_dst != nullptr) +void aclnn_div(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_other, aclTensor * acl_dst) { + if (acl_dst != nullptr) { GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst); - else + } else { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other); + } } /** @@ -260,9 +267,12 @@ void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param inplace Flag indicating whether to perform the operation in-place on * `acl_src`. */ -static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, - float scale, aclTensor* acl_dst, bool inplace) { - aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); +static void aclnn_muls(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + float scale, + aclTensor * acl_dst, + bool inplace) { + aclScalar * acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); if (inplace) { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale); } else { @@ -271,19 +281,18 @@ static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, ggml_cann_release_resources(ctx, acl_scale); } -void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - aclScalar* acl_negative_slope = - aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT); + aclScalar * acl_negative_slope = aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src, acl_negative_slope, acl_dst); ggml_cann_release_resources(ctx, acl_negative_slope, acl_src, acl_dst); @@ -299,26 +308,27 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * stored. * @param concat_dim The dimension along which the tensors will be concatenated. */ -static void aclnn_concat(ggml_backend_cann_context& ctx, - aclTensorList* tensorList, aclTensor* acl_dst, - int64_t concat_dim) { +static void aclnn_concat(ggml_backend_cann_context & ctx, + aclTensorList * tensorList, + aclTensor * acl_dst, + int64_t concat_dim) { GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst); } -void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - aclTensor* acl_src1 = ggml_cann_create_tensor(src1); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); +void ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + aclTensor * acl_src0 = ggml_cann_create_tensor(src0); + aclTensor * acl_src1 = ggml_cann_create_tensor(src1); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); const int32_t dim = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(dim >= 0 && dim < 4); int32_t acl_dim = 3 - dim; - aclTensor* tensors[] = {acl_src0, acl_src1}; - aclTensorList* tensor_list = aclCreateTensorList(tensors, 2); + aclTensor * tensors[] = { acl_src0, acl_src1 }; + aclTensorList * tensor_list = aclCreateTensorList(tensors, 2); aclnn_concat(ctx, tensor_list, acl_dst, acl_dim); ggml_cann_release_resources(ctx, tensor_list, acl_dst); @@ -341,162 +351,157 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param step The step size between consecutive values. * @param n_elements The number of elements in the destination tensor. */ -static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst, - float start, float stop, float step, - int64_t n_elements) { - int64_t steps = (int64_t)std::ceil((stop - start) / step); +static void aclnn_arange(ggml_backend_cann_context & ctx, + aclTensor * acl_dst, + float start, + float stop, + float step, + int64_t n_elements) { + int64_t steps = (int64_t) std::ceil((stop - start) / step); GGML_ASSERT(n_elements == steps); - aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT); - aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT); - aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT); + aclScalar * acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT); + aclScalar * acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT); + aclScalar * acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start, acl_end, acl_step, acl_dst); ggml_cann_release_resources(ctx, acl_start, acl_end, acl_step); } -void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +void ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->type == GGML_TYPE_F32); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); int64_t n_elements = ggml_nelements(dst); - float start; - float stop; - float step; - memcpy(&start, (float*)dst->op_params + 0, sizeof(float)); - memcpy(&stop, (float*)dst->op_params + 1, sizeof(float)); - memcpy(&step, (float*)dst->op_params + 2, sizeof(float)); + float start; + float stop; + float step; + memcpy(&start, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&stop, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&step, (float *) dst->op_params + 2, sizeof(float)); aclnn_arange(ctx, acl_dst, start, stop, step, n_elements); ggml_cann_release_resources(ctx, acl_dst); } -void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; float min; float max; memcpy(&min, dst->op_params, sizeof(float)); - memcpy(&max, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT); - aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT); + aclScalar * acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT); + aclScalar * acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src, acl_min, acl_max, acl_dst); ggml_cann_release_resources(ctx, acl_min, acl_max, acl_src, acl_dst); } -void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; // scale factor float v; memcpy(&v, dst->op_params, sizeof(float)); - aclScalar* scale = aclCreateScalar(&v, aclDataType::ACL_FLOAT); - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclScalar * scale = aclCreateScalar(&v, aclDataType::ACL_FLOAT); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, scale, acl_dst); ggml_cann_release_resources(ctx, scale, acl_src, acl_dst); } -void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; - enum ggml_sort_order order = (enum ggml_sort_order)dst->op_params[0]; +void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); - ggml_cann_pool_alloc temp_buffer_allocator( - ctx.pool(), ggml_nelements(dst) * sizeof(int64_t)); - void* buffer = temp_buffer_allocator.get(); - aclTensor* tmp_tensor = - ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type), - dst->ne, dst->nb, GGML_MAX_DIMS); - GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), - tmp_tensor); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); + ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(int64_t)); + void * buffer = temp_buffer_allocator.get(); + aclTensor * tmp_tensor = + ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type), dst->ne, dst->nb, GGML_MAX_DIMS); + GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor); GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst); ggml_cann_release_resources(ctx, acl_src, tmp_tensor, acl_dst); } -void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - std::vector normData = {dst->ne[0]}; - aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size()); - GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr, - eps, acl_dst, nullptr, nullptr); + std::vector normData = { dst->ne[0] }; + aclIntArray * norm = aclCreateIntArray(normData.data(), normData.size()); + GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr, eps, acl_dst, nullptr, nullptr); ggml_cann_release_resources(ctx, norm, acl_src, acl_dst); } -void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); int n_groups = dst->op_params[0]; float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); - int64_t N = src->ne[3]; - int64_t C = src->ne[2]; + int64_t N = src->ne[3]; + int64_t C = src->ne[2]; int64_t HxW = src->ne[1] * src->ne[0]; - size_t type_size = ggml_type_size(src->type); - int64_t ne[] = {n_groups, N}; - size_t nb[] = {type_size, type_size * n_groups}; - size_t n_bytes = N * n_groups; + size_t type_size = ggml_type_size(src->type); + int64_t ne[] = { n_groups, N }; + size_t nb[] = { type_size, type_size * n_groups }; + size_t n_bytes = N * n_groups; ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes * 2); - void* buffer = temp_buffer_allocator.get(); - aclTensor* acl_mean_out = ggml_cann_create_tensor( - buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); - aclTensor* acl_rstd_out = ggml_cann_create_tensor( - (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); + void * buffer = temp_buffer_allocator.get(); + aclTensor * acl_mean_out = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); + aclTensor * acl_rstd_out = + ggml_cann_create_tensor((char *) buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); - GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, - acl_dst, acl_mean_out, acl_rstd_out); + GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst, acl_mean_out, + acl_rstd_out); ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_mean_out, acl_rstd_out); } -void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; +void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; - size_t nb1 = ((int32_t*)dst->op_params)[0]; - size_t nb2 = ((int32_t*)dst->op_params)[1]; - size_t nb3 = ((int32_t*)dst->op_params)[2]; - size_t offset = ((int32_t*)dst->op_params)[3]; - bool inplace = (bool)((int32_t*)dst->op_params)[4]; + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - size_t param_nb[] = {ggml_element_size(src0), nb1, nb2, nb3}; + size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 }; - aclTensor* acl_dst = ggml_cann_create_tensor( - dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); - aclTensor* acl_src1 = ggml_cann_create_tensor(src1); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); + aclTensor * acl_src1 = ggml_cann_create_tensor(src1); - aclScalar* alpha = nullptr; - float alphaValue = 1.0f; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + aclScalar * alpha = nullptr; + float alphaValue = 1.0f; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); if (!inplace) { size_t cpy_size = ggml_nbytes(dst); - ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE); - aclTensor* acl_src0 = ggml_cann_create_tensor( - src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); + ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE); + aclTensor * acl_src0 = ggml_cann_create_tensor(src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst); ggml_cann_release_resources(ctx, acl_src0); @@ -516,39 +521,34 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param dim An array of dimension indices. * @param dim_size The number of dimensions. */ -static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst, - int64_t* dim, size_t dim_size) { +static void aclnn_reduce_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t * dim, size_t dim_size) { GGML_ASSERT(dst->ne[0] == 1); - ggml_tensor* src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); - aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size); + ggml_tensor * src = dst->src[0]; + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); + aclIntArray * reduce_dims = aclCreateIntArray(dim, dim_size); - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true, - ggml_cann_type_mapping(dst->type), acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true, ggml_cann_type_mapping(dst->type), acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims); } -void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - int64_t reduce_dims[] = {3}; +void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + int64_t reduce_dims[] = { 3 }; aclnn_reduce_sum(ctx, dst, reduce_dims, 1); } -void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - int64_t reduce_dims[] = {0, 1, 2, 3}; +void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + int64_t reduce_dims[] = { 0, 1, 2, 3 }; aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } -void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, - ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; - aclTensor* acl_src = - ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); - aclTensor* acl_dst = - ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW); +void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + aclTensor * acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW); - std::vector output_size{dst->ne[1], dst->ne[0]}; - auto output_size_array = aclCreateIntArray(output_size.data(), 2); + std::vector output_size{ dst->ne[1], dst->ne[0] }; + auto output_size_array = aclCreateIntArray(output_size.data(), 2); GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src, output_size_array, acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst, output_size_array); @@ -568,20 +568,22 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, * The size of the array should be twice the number of dimensions of the tensor. * @param value The value to be used for padding. The default value is 0.0. */ -static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, int64_t* paddings, - float value = 0.0f) { - aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2); - aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); +static void aclnn_pad(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_dst, + int64_t * paddings, + float value = 0.0f) { + aclIntArray * acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2); + aclScalar * acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst); ggml_cann_release_resources(ctx, acl_pad, acl_value); } -void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); +void ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); // padding: value in the array means how much distance will be padding. // the position of elements in the array means which dirction to padding, @@ -596,7 +598,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const int32_t lp3 = ggml_get_op_params_i32(dst, 6); const int32_t rp3 = ggml_get_op_params_i32(dst, 7); - int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3}; + int64_t paddings[] = { lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 }; aclnn_pad(ctx, acl_src, acl_dst, paddings); ggml_cann_release_resources(ctx, acl_src, acl_dst); } @@ -613,46 +615,41 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param dst The destination tensor where the result will be stored. The source * tensor is referenced by `dst->src[0]`. */ -static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, - ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +static void ggml_cann_avg_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - aclTensor* acl_src = - ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); - aclTensor* acl_dst = - ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW); + aclTensor * acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW); - const int32_t* opts = (const int32_t*)dst->op_params; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; + const int32_t * opts = (const int32_t *) dst->op_params; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; - std::vector kernel_dims = {k1, k0}; - std::vector stride_dims = {s1, s0}; - std::vector padding_avg_dims = {p1, p0}; // (padH, padW) + std::vector kernel_dims = { k1, k0 }; + std::vector stride_dims = { s1, s0 }; + std::vector padding_avg_dims = { p1, p0 }; // (padH, padW) - auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2); - auto* strides = aclCreateIntArray(stride_dims.data(), 2); - auto* paddings_avg = aclCreateIntArray(padding_avg_dims.data(), 2); + auto * kernel_size = aclCreateIntArray(kernel_dims.data(), 2); + auto * strides = aclCreateIntArray(stride_dims.data(), 2); + auto * paddings_avg = aclCreateIntArray(padding_avg_dims.data(), 2); - bool ceil_mode = false; - bool count_include_pad = true; - int64_t divisor_override = 0; - int8_t cube_math_type = 0; + bool ceil_mode = false; + bool count_include_pad = true; + int64_t divisor_override = 0; + int8_t cube_math_type = 0; #ifdef ASCEND_310P cube_math_type = 1; #endif - GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg, - ceil_mode, count_include_pad, divisor_override, - cube_math_type, acl_dst); - ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides, - paddings_avg); + GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg, ceil_mode, count_include_pad, + divisor_override, cube_math_type, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides, paddings_avg); } /** @@ -667,68 +664,61 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, * @param dst The destination tensor where the result will be stored. The source * tensor is referenced by `dst->src[0]`. */ -static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx, - ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +static void ggml_cann_max_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - aclTensor* acl_src = - ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); - aclTensor* acl_dst = - ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW); + aclTensor * acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW); - const int32_t* opts = (const int32_t*)dst->op_params; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; + const int32_t * opts = (const int32_t *) dst->op_params; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; - int64_t temp_ne[] = {src->ne[0] + p0 * 2, src->ne[1] + p1 * 2, src->ne[2], - src->ne[3]}; - size_t temp_nb[GGML_MAX_DIMS]; + int64_t temp_ne[] = { src->ne[0] + p0 * 2, src->ne[1] + p1 * 2, src->ne[2], src->ne[3] }; + size_t temp_nb[GGML_MAX_DIMS]; temp_nb[0] = ggml_element_size(src); for (int i = 1; i < GGML_MAX_DIMS; i++) { temp_nb[i] = temp_nb[i - 1] * temp_ne[i - 1]; } - ggml_cann_pool_alloc temp_buffer_allocator( - ctx.pool(), ggml_nbytes(src) + p0 * 2 + p1 * 2 * src->nb[1]); - void* buffer = temp_buffer_allocator.get(); - aclTensor* tmp_tensor = ggml_cann_create_tensor( - buffer, ACL_FLOAT, ggml_element_size(src), temp_ne, temp_nb, - GGML_MAX_DIMS, ACL_FORMAT_NCHW); + ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), ggml_nbytes(src) + p0 * 2 + p1 * 2 * src->nb[1]); + void * buffer = temp_buffer_allocator.get(); + aclTensor * tmp_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT, ggml_element_size(src), temp_ne, temp_nb, + GGML_MAX_DIMS, ACL_FORMAT_NCHW); // pad: see padding in ggml_cann_pad() - int64_t paddings[] = {p0, p0, p1, p1, 0, 0, 0, 0}; - float value = -FLT_MAX; + int64_t paddings[] = { p0, p0, p1, p1, 0, 0, 0, 0 }; + float value = -FLT_MAX; aclnn_pad(ctx, acl_src, tmp_tensor, paddings, value); // max_pool - std::vector kernel_dims = {k1, k0}; - std::vector stride_dims = {s1, s0}; + std::vector kernel_dims = { k1, k0 }; + std::vector stride_dims = { s1, s0 }; // padding_max_dims: [dim0_start, dim0_end, dim1_start, dim1_end] - std::vector padding_max_dims = {0, 0, 0, 0}; - std::vector dilation_size = {1, 1}; - auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2); - auto* strides = aclCreateIntArray(stride_dims.data(), 2); - auto* paddings_max = aclCreateIntArray(padding_max_dims.data(), 4); - auto* dilations = aclCreateIntArray(dilation_size.data(), 2); + std::vector padding_max_dims = { 0, 0, 0, 0 }; + std::vector dilation_size = { 1, 1 }; + auto * kernel_size = aclCreateIntArray(kernel_dims.data(), 2); + auto * strides = aclCreateIntArray(stride_dims.data(), 2); + auto * paddings_max = aclCreateIntArray(padding_max_dims.data(), 4); + auto * dilations = aclCreateIntArray(dilation_size.data(), 2); - bool ceil_mode = false; + bool ceil_mode = false; int64_t auto_pads = 0; - GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads, - paddings_max, dilations, ceil_mode, acl_dst); - ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size, - strides, paddings_max, dilations); + GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations, + ceil_mode, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size, strides, paddings_max, dilations); } -void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - const int32_t* opts = (const int32_t*)dst->op_params; - enum ggml_op_pool op = static_cast(opts[0]); +void ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + const int32_t * opts = (const int32_t *) dst->op_params; + enum ggml_op_pool op = static_cast(opts[0]); switch (op) { case GGML_OP_POOL_AVG: ggml_cann_avg_pool2d(ctx, dst); @@ -752,17 +742,16 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param acl_src The source tensor from which data will be copied. * @param acl_dst The destination tensor where the data will be copied to. */ -static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst) { +static void cann_copy(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src); } -void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; +void ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; if (ggml_are_same_shape(src0, dst)) { - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); if (dst->type == src0->type) { cann_copy(ctx, acl_src, acl_dst); } else { @@ -770,22 +759,20 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } ggml_cann_release_resources(ctx, acl_src, acl_dst); } else { - void* src_trans_buffer = src0->data; + void * src_trans_buffer = src0->data; ggml_cann_pool_alloc src_buffer_allocator; if (!ggml_is_contiguous(src0)) { - aclTensor* acl_src = ggml_cann_create_tensor(src0); - src_buffer_allocator.alloc(ctx.pool(), - ggml_nelements(src0) * ggml_type_size(src0->type)); + aclTensor * acl_src = ggml_cann_create_tensor(src0); + src_buffer_allocator.alloc(ctx.pool(), ggml_nelements(src0) * ggml_type_size(src0->type)); src_trans_buffer = src_buffer_allocator.get(); size_t src_trans_nb[GGML_MAX_DIMS]; src_trans_nb[0] = ggml_type_size(src0->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } - aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src_trans_nb, - GGML_MAX_DIMS); + aclTensor * src_trans_tensor = + ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); cann_copy(ctx, acl_src, src_trans_tensor); ggml_cann_release_resources(ctx, acl_src, src_trans_tensor); } @@ -796,10 +783,10 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1]; } - aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer, - ggml_cann_type_mapping(src0->type),ggml_type_size(src0->type), - dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * trans_acl_src = + ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); if (dst->type == src0->type) { cann_copy(ctx, trans_acl_src, acl_dst); @@ -827,17 +814,20 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param type_size The size of each element in the tensor data type. * @return An ACL tensor initialized with zeros. */ -static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, - size_t n_bytes, int64_t* ne, int64_t dims, - aclDataType type, size_t type_size) { +static aclTensor * aclnn_zero(ggml_backend_cann_context & ctx, + void * buffer, + size_t n_bytes, + int64_t * ne, + int64_t dims, + aclDataType type, + size_t type_size) { size_t nb[GGML_MAX_DIMS]; nb[0] = type_size; for (int i = 1; i < dims; i++) { nb[i] = nb[i - 1] * ne[i - 1]; } - aclTensor* zero = - ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims); + aclTensor * zero = ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero); return zero; GGML_UNUSED(n_bytes); @@ -861,15 +851,18 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, * is 1.0). * @return An ACL tensor initialized with value. */ -static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, - size_t n_bytes, int64_t* ne, int64_t dims, - aclDataType type, size_t type_size, - float value = 1.0f) { - aclTensor* acl_tensor = - aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size); - float alpha_host = 1.0f; - aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT); - aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT); +static aclTensor * aclnn_values(ggml_backend_cann_context & ctx, + void * buffer, + size_t n_bytes, + int64_t * ne, + int64_t dims, + aclDataType type, + size_t type_size, + float value = 1.0f) { + aclTensor * acl_tensor = aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size); + float alpha_host = 1.0f; + aclScalar * alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT); + aclScalar * other = aclCreateScalar(&value, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor, other, alpha); return acl_tensor; } @@ -884,8 +877,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, * @param scalar The scalar value used to fill the tensor. * @param acl_dst The destination tensor to be filled with the scalar value. */ -static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, - aclTensor* acl_dst) { +static void aclnn_fill_scalar(ggml_backend_cann_context & ctx, float scalar, aclTensor * acl_dst) { auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar); ggml_cann_release_resources(ctx, acl_scalar); @@ -913,15 +905,14 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, * initialization via memset or arbitrary values via fill_scalar). * @return An aclTensor pointer created from the cached buffer. */ -static aclTensor* get_cache_acl_tensor( - ggml_backend_cann_context& ctx, - void** buffer, - int64_t &cache_element, - int64_t* ne, - size_t* nb, - ggml_type dtype, - int64_t dims, - float value) { +static aclTensor * get_cache_acl_tensor(ggml_backend_cann_context & ctx, + void ** buffer, + int64_t & cache_element, + int64_t * ne, + size_t * nb, + ggml_type dtype, + int64_t dims, + float value) { // Calculate total number of elements int64_t n_element = 1; for (int i = 0; i < dims; i++) { @@ -940,24 +931,22 @@ static aclTensor* get_cache_acl_tensor( cache_element = n_element; // Initialize cache - int64_t pool_ne[1] = { n_element }; - size_t pool_nb[1] = { ggml_type_size(dtype) }; - aclTensor* acl_value = ggml_cann_create_tensor( - *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype), - pool_ne, pool_nb, 1); + int64_t pool_ne[1] = { n_element }; + size_t pool_nb[1] = { ggml_type_size(dtype) }; + aclTensor * acl_value = + ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype), pool_ne, pool_nb, 1); aclnn_fill_scalar(ctx, value, acl_value); ggml_cann_release_resources(ctx, acl_value); } - return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), - ggml_type_size(dtype), ne, nb, dims); + return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype), ne, nb, dims); } -void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -969,61 +958,50 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { for (int i = 1; i < GGML_MAX_DIMS; i++) { acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1]; } - aclTensor* acl_gamma = get_cache_acl_tensor( - ctx, - &ctx.rms_norm_one_tensor_cache.cache, - ctx.rms_norm_one_tensor_cache.size, - src->ne, - acl_gamma_nb, - dst->type, - 1, // dims - 1.0f // value + aclTensor * acl_gamma = get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, + ctx.rms_norm_one_tensor_cache.size, src->ne, acl_gamma_nb, dst->type, + 1, // dims + 1.0f // value ); // build rstd. - int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]}; - size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; + int64_t acl_rstd_ne[] = { src->ne[1], src->ne[2], src->ne[3] }; + size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; // rstd will always be F32. acl_rstd_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1]; } - aclTensor* acl_rstd = get_cache_acl_tensor( - ctx, - &ctx.rms_norm_zero_tensor_cache.cache, - ctx.rms_norm_zero_tensor_cache.size, - acl_rstd_ne, - acl_rstd_nb, - GGML_TYPE_F32, - GGML_MAX_DIMS - 1, - 0.0f // value - ); + aclTensor * acl_rstd = + get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size, + acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS - 1, + 0.0f // value + ); GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd); ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd); } // TODO: performace is low. -void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, - float value) { - ggml_tensor* src = dst->src[0]; +void ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value) { + ggml_tensor * src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - const int n_past = ((int32_t*)dst->op_params)[0]; + const int n_past = ((int32_t *) dst->op_params)[0]; ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src)); - void* buffer = one_tensor_allocator.get(); + void * buffer = one_tensor_allocator.get(); - aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type), - ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS); + aclTensor * mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type), + ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS); aclnn_fill_scalar(ctx, value, mask_tensor); - aclScalar* alpha = nullptr; - float alphaValue = 1.0f; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + aclScalar * alpha = nullptr; + float alphaValue = 1.0f; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor, n_past + 1); GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src, n_past + 1, acl_dst); @@ -1046,25 +1024,27 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, * tensor. * @param dims The number of dimensions in the tensor. */ -static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) { - aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims); +static void aclnn_permute(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_dst, + int64_t * new_dim, + uint64_t dims) { + aclIntArray * acl_dims = aclCreateIntArray(new_dim, dims); GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims, acl_dst); ggml_cann_release_resources(ctx, acl_dims); } -static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, - ggml_tensor* dst, - ggml_tensor* src1, - aclTensor* tmp_cast_tensor, - aclTensor* tmp_im2col_tensor) { +static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context & ctx, + ggml_tensor * dst, + ggml_tensor * src1, + aclTensor * tmp_cast_tensor, + aclTensor * tmp_im2col_tensor) { // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW] - int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]}; - size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]}; - aclTensor* acl_dst = - ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1); + int64_t dst_ne[] = { dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3] }; + size_t dst_nb[] = { dst->nb[0], dst->nb[1], dst->nb[3] }; + aclTensor * acl_dst = ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1); - int64_t permute_dim[] = {0, 2, 1}; + int64_t permute_dim[] = { 0, 2, 1 }; if (src1->type != dst->type) { aclnn_permute(ctx, tmp_cast_tensor, acl_dst, permute_dim, 3); } else { @@ -1074,101 +1054,95 @@ static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, ggml_cann_release_resources(ctx, acl_dst); } -static void ggml_cann_im2col_1d_post_process( - ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1, - aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor, - const std::vector& im2col_op_params) { +static void ggml_cann_im2col_1d_post_process(ggml_backend_cann_context & ctx, + ggml_tensor * dst, + ggml_tensor * src1, + aclTensor * tmp_cast_tensor, + aclTensor * tmp_im2col_tensor, + const std::vector & im2col_op_params) { // get params - const int64_t KH = im2col_op_params[0]; - const int64_t KW = im2col_op_params[1]; - const int64_t IW = im2col_op_params[2]; - const int64_t IC = im2col_op_params[3]; - const int64_t N = im2col_op_params[4]; - const int64_t OH = im2col_op_params[5]; - const int64_t OW = im2col_op_params[6]; - const int64_t s0 = im2col_op_params[7]; - const int64_t p0 = im2col_op_params[8]; - const int64_t d0 = im2col_op_params[9]; + const int64_t KH = im2col_op_params[0]; + const int64_t KW = im2col_op_params[1]; + const int64_t IW = im2col_op_params[2]; + const int64_t IC = im2col_op_params[3]; + const int64_t N = im2col_op_params[4]; + const int64_t OH = im2col_op_params[5]; + const int64_t OW = im2col_op_params[6]; + const int64_t s0 = im2col_op_params[7]; + const int64_t p0 = im2col_op_params[8]; + const int64_t d0 = im2col_op_params[9]; const int64_t n_bytes_factor = im2col_op_params[10]; // Permute: [N, IC * KH * KW, OW * OH] -> // [N, OW * OH * n_bytes_factor, IC * KH * KW] ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool()); tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor); - void* tmp_permute_buffer = tmp_permute_allocator.get(); + void * tmp_permute_buffer = tmp_permute_allocator.get(); - int64_t tmp_permute_ne[] = {IC * KH * KW, OW * OH * n_bytes_factor, N}; - size_t tmp_permute_nb[GGML_MAX_DIMS - 1]; + int64_t tmp_permute_ne[] = { IC * KH * KW, OW * OH * n_bytes_factor, N }; + size_t tmp_permute_nb[GGML_MAX_DIMS - 1]; tmp_permute_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1]; } - aclTensor* tmp_permute_tensor = ggml_cann_create_tensor( - tmp_permute_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb, - GGML_MAX_DIMS - 1, ACL_FORMAT_ND); + aclTensor * tmp_permute_tensor = + ggml_cann_create_tensor(tmp_permute_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); - int64_t permute_dim[] = {0, 2, 1}; + int64_t permute_dim[] = { 0, 2, 1 }; if (src1->type != dst->type) { aclnn_permute(ctx, tmp_cast_tensor, tmp_permute_tensor, permute_dim, 3); } else { - aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor, permute_dim, - 3); + aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor, permute_dim, 3); } // number of times the kernel moves in W dimension const int n_step_w = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1; - size_t offset; - void *cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer; + size_t offset; + void * cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer; // memory copy with offset to restore 1D im2col from 2d if (IC > 1) { - offset = IC * KH * KW * n_step_w * ggml_type_size(dst->type); + offset = IC * KH * KW * n_step_w * ggml_type_size(dst->type); size_t size_cpy = KH * KW * ggml_type_size(dst->type); for (int c = 0; c < IC; c++) { - cur_permute_buffer = (char*)tmp_permute_buffer + offset + - KH * KW * c * ggml_type_size(dst->type); - cur_dst_buffer = (char*)dst->data + - c * KH * KW * n_step_w * ggml_type_size(dst->type); + cur_permute_buffer = (char *) tmp_permute_buffer + offset + KH * KW * c * ggml_type_size(dst->type); + cur_dst_buffer = (char *) dst->data + c * KH * KW * n_step_w * ggml_type_size(dst->type); for (int i = 0; i < n_step_w; i++) { - ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy, - ACL_MEMCPY_DEVICE_TO_DEVICE); - cur_dst_buffer = - (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type); - cur_permute_buffer = (char*)cur_permute_buffer + - KH * KW * IC * ggml_type_size(dst->type); + ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy, ACL_MEMCPY_DEVICE_TO_DEVICE); + cur_dst_buffer = (char *) cur_dst_buffer + KH * KW * ggml_type_size(dst->type); + cur_permute_buffer = (char *) cur_permute_buffer + KH * KW * IC * ggml_type_size(dst->type); } } } else { - offset = KH * KW * n_step_w * - ggml_type_size(dst->type); // equal to ggml_nbytes(dst) - ggml_cann_async_memcpy(ctx, dst->data, (char*)tmp_permute_buffer + offset, offset, - ACL_MEMCPY_DEVICE_TO_DEVICE); + offset = KH * KW * n_step_w * ggml_type_size(dst->type); // equal to ggml_nbytes(dst) + ggml_cann_async_memcpy(ctx, dst->data, (char *) tmp_permute_buffer + offset, offset, + ACL_MEMCPY_DEVICE_TO_DEVICE); } ggml_cann_release_resources(ctx, tmp_permute_tensor); } -void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; // kernel - ggml_tensor* src1 = dst->src[1]; // input +void ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // kernel + ggml_tensor * src1 = dst->src[1]; // input GGML_TENSOR_BINARY_OP_LOCALS; // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D // im2col and do post-processing to restore it to 1D. - const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = is_2D ? ((const int32_t*)(dst->op_params))[1] : 1; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = is_2D ? ((const int32_t*)(dst->op_params))[3] : 1; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = is_2D ? ((const int32_t*)(dst->op_params))[5] : 1; + const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; + const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t s1 = is_2D ? ((const int32_t *) (dst->op_params))[1] : 1; + const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; + const int32_t p1 = is_2D ? ((const int32_t *) (dst->op_params))[3] : 1; + const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; + const int32_t d1 = is_2D ? ((const int32_t *) (dst->op_params))[5] : 1; - const int64_t N = ne13; + const int64_t N = ne13; const int64_t IC = ne12; const int64_t KH = ne01; const int64_t KW = ne00; @@ -1181,9 +1155,9 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const int64_t n_bytes_factor = is_2D ? 1 : 3; // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH * n_bytes_factor] - aclTensor* acl_src1 = ggml_cann_create_tensor(src1); - int64_t tmp_im2col_ne[] = {OW * OH * n_bytes_factor, IC * KH * KW, N}; - size_t tmp_im2col_nb[GGML_MAX_DIMS - 1]; + aclTensor * acl_src1 = ggml_cann_create_tensor(src1); + int64_t tmp_im2col_ne[] = { OW * OH * n_bytes_factor, IC * KH * KW, N }; + size_t tmp_im2col_nb[GGML_MAX_DIMS - 1]; tmp_im2col_nb[0] = ggml_type_size(src1->type); for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { @@ -1193,31 +1167,27 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // Calculate im2col. // If dst is f16, tmp_buffer is f32, we need alloc src.typesize * // dst.elemcount. - ggml_cann_pool_alloc im2col_allocator( - ctx.pool(), - ggml_nelements(dst) * ggml_element_size(src1) * n_bytes_factor); - void* tmp_im2col_buffer = im2col_allocator.get(); + ggml_cann_pool_alloc im2col_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(src1) * n_bytes_factor); + void * tmp_im2col_buffer = im2col_allocator.get(); - aclTensor* tmp_im2col_tensor = ggml_cann_create_tensor( - tmp_im2col_buffer, ggml_cann_type_mapping(src1->type), - ggml_type_size(src1->type), tmp_im2col_ne, tmp_im2col_nb, - GGML_MAX_DIMS - 1, ACL_FORMAT_ND); + aclTensor * tmp_im2col_tensor = + ggml_cann_create_tensor(tmp_im2col_buffer, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), + tmp_im2col_ne, tmp_im2col_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); - std::vector kernel_dims = {KH, KW}; - std::vector dilation_size = {d1, d0}; - std::vector padding_dims = {p1, p0}; - std::vector stride_dims = {s1, s0}; - auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2); - auto* dilations = aclCreateIntArray(dilation_size.data(), 2); - auto* paddings = aclCreateIntArray(padding_dims.data(), 2); - auto* strides = aclCreateIntArray(stride_dims.data(), 2); - GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations, - paddings, strides, tmp_im2col_tensor); + std::vector kernel_dims = { KH, KW }; + std::vector dilation_size = { d1, d0 }; + std::vector padding_dims = { p1, p0 }; + std::vector stride_dims = { s1, s0 }; + auto * kernel_size = aclCreateIntArray(kernel_dims.data(), 2); + auto * dilations = aclCreateIntArray(dilation_size.data(), 2); + auto * paddings = aclCreateIntArray(padding_dims.data(), 2); + auto * strides = aclCreateIntArray(stride_dims.data(), 2); + GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations, paddings, strides, tmp_im2col_tensor); // Cast if dst is f16. - aclTensor* tmp_cast_tensor = nullptr; + aclTensor * tmp_cast_tensor = nullptr; ggml_cann_pool_alloc tmp_cast_allocator(ctx.pool()); - void* tmp_cast_buffer = nullptr; + void * tmp_cast_buffer = nullptr; if (src1->type != dst->type) { tmp_cast_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor); tmp_cast_buffer = tmp_cast_allocator.get(); @@ -1227,26 +1197,22 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1]; } - tmp_cast_tensor = ggml_cann_create_tensor( - tmp_cast_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb, - GGML_MAX_DIMS - 1, ACL_FORMAT_ND); + tmp_cast_tensor = + ggml_cann_create_tensor(tmp_cast_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + tmp_im2col_ne, temp_cast_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping(dst->type)); } // post-processing if (is_2D) { - ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor, - tmp_im2col_tensor); + ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor, tmp_im2col_tensor); } else { - std::vector im2col_op_params = { - KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor}; - ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor, - tmp_im2col_tensor, im2col_op_params); + std::vector im2col_op_params = { KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor }; + ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor, tmp_im2col_tensor, im2col_op_params); } - ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor, - kernel_size, dilations, paddings, strides); + ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor, kernel_size, dilations, paddings, + strides); } /** @@ -1262,136 +1228,123 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param ctx The context for the CANN backend operations. * @param acl_src The tensor on which the exponential function will be applied. */ -static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { +static void aclnn_exp(ggml_backend_cann_context & ctx, aclTensor * acl_src) { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src); } -void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst) { - if(acl_dst == nullptr) { +void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + if (acl_dst == nullptr) { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src); } else { GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); } } -void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst) { - if(acl_dst == nullptr) { +void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + if (acl_dst == nullptr) { GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src); } else { GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); } } -void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, - ggml_tensor* dst) { - const ggml_tensor* src = dst->src[0]; +void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src = dst->src[0]; GGML_ASSERT(src->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int dim = dst->op_params[0]; + const int dim = dst->op_params[0]; const int max_period = dst->op_params[1]; - int half = dim / 2; + int half = dim / 2; - aclTensor* acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_src = ggml_cann_create_tensor(src); // arange: [0, ..., half) - float start = 0; - float stop = half; - float step = 1; + float start = 0; + float stop = half; + float step = 1; int64_t n_elements_arange = half; - int64_t tmp_arange_ne[] = {half}; - size_t tmp_arange_nb[] = {sizeof(dst->type)}; + int64_t tmp_arange_ne[] = { half }; + size_t tmp_arange_nb[] = { sizeof(dst->type) }; ggml_cann_pool_alloc arange_allocator(ctx.pool(), half * sizeof(dst->type)); - void* tmp_arange_buffer = arange_allocator.get(); - aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( - tmp_arange_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_arange_ne, tmp_arange_nb, - GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + void * tmp_arange_buffer = arange_allocator.get(); + aclTensor * tmp_arange_tensor = + ggml_cann_create_tensor(tmp_arange_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + tmp_arange_ne, tmp_arange_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); aclnn_arange(ctx, tmp_arange_tensor, start, stop, step, n_elements_arange); // freq float freq_param = -logf(max_period) / half; - bool inplace = true; + bool inplace = true; aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace); aclnn_exp(ctx, tmp_arange_tensor); // permute: src [0,1,2,3]->[0,1,3,2] - int64_t tmp_permute_ne[] = {src->ne[1], src->ne[0], src->ne[2], src->ne[3]}; - size_t tmp_permute_nb[GGML_MAX_DIMS]; + int64_t tmp_permute_ne[] = { src->ne[1], src->ne[0], src->ne[2], src->ne[3] }; + size_t tmp_permute_nb[GGML_MAX_DIMS]; tmp_permute_nb[0] = ggml_type_size(src->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1]; } ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src)); - void* tmp_permute_buffer = permute_allocator.get(); - aclTensor* tmp_permute_tensor = ggml_cann_create_tensor( - tmp_permute_buffer, ggml_cann_type_mapping(src->type), - ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - int64_t permute_dim[] = {0, 1, 3, 2}; - int64_t num_dims = 4; + void * tmp_permute_buffer = permute_allocator.get(); + aclTensor * tmp_permute_tensor = + ggml_cann_create_tensor(tmp_permute_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type), + tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + int64_t permute_dim[] = { 0, 1, 3, 2 }; + int64_t num_dims = 4; aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims); // timestep * freq - int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2], - src->ne[3]}; - size_t tmp_mul_nb[GGML_MAX_DIMS]; + int64_t tmp_mul_ne[] = { src->ne[1] * half, src->ne[0], src->ne[2], src->ne[3] }; + size_t tmp_mul_nb[GGML_MAX_DIMS]; tmp_mul_nb[0] = ggml_type_size(src->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { tmp_mul_nb[i] = tmp_mul_nb[i - 1] * tmp_mul_ne[i - 1]; } - int mul_nelements = - src->ne[1] * half * src->ne[0] * src->ne[2] * src->ne[3]; + int mul_nelements = src->ne[1] * half * src->ne[0] * src->ne[2] * src->ne[3]; - ggml_cann_pool_alloc mul_allocator( - ctx.pool(), mul_nelements * ggml_type_size(src->type)); - void* tmp_mul_buffer = mul_allocator.get(); - aclTensor* tmp_mul_tensor = ggml_cann_create_tensor( - tmp_mul_buffer, ggml_cann_type_mapping(src->type), - ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); + ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_nelements * ggml_type_size(src->type)); + void * tmp_mul_buffer = mul_allocator.get(); + aclTensor * tmp_mul_tensor = + ggml_cann_create_tensor(tmp_mul_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type), + tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor); // cos - ggml_cann_pool_alloc cos_allocator( - ctx.pool(), mul_nelements * ggml_type_size(src->type)); - void* tmp_cos_buffer = cos_allocator.get(); - aclTensor* tmp_cos_tensor = ggml_cann_create_tensor( - tmp_cos_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); + ggml_cann_pool_alloc cos_allocator(ctx.pool(), mul_nelements * ggml_type_size(src->type)); + void * tmp_cos_buffer = cos_allocator.get(); + aclTensor * tmp_cos_tensor = + ggml_cann_create_tensor(tmp_cos_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, tmp_mul_tensor, tmp_cos_tensor); // sin - ggml_cann_pool_alloc sin_allocator( - ctx.pool(), mul_nelements * ggml_type_size(src->type)); - void* tmp_sin_buffer = sin_allocator.get(); - aclTensor* tmp_sin_tensor = ggml_cann_create_tensor( - tmp_sin_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, - ACL_FORMAT_ND); + ggml_cann_pool_alloc sin_allocator(ctx.pool(), mul_nelements * ggml_type_size(src->type)); + void * tmp_sin_buffer = sin_allocator.get(); + aclTensor * tmp_sin_tensor = + ggml_cann_create_tensor(tmp_sin_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, tmp_mul_tensor, tmp_sin_tensor); // concat - int64_t concat_dim = 3; - aclTensor* acl_dst = ggml_cann_create_tensor(dst); - aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor}; - aclTensorList* tensor_list = aclCreateTensorList(tensors, 2); + int64_t concat_dim = 3; + aclTensor * acl_dst = ggml_cann_create_tensor(dst); + aclTensor * tensors[] = { tmp_cos_tensor, tmp_sin_tensor }; + aclTensorList * tensor_list = aclCreateTensorList(tensors, 2); aclnn_concat(ctx, tensor_list, acl_dst, concat_dim); // release // segmentation fault when delete both tensorList and his elements. - ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor, - tmp_permute_tensor, tmp_mul_tensor, acl_dst); + ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor, tmp_permute_tensor, tmp_mul_tensor, + acl_dst); } /** @@ -1410,8 +1363,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, * @param acl_exp The exponent tensor, each element of which is used to raise * the corresponding element in the destination tensor. */ -static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, - aclTensor* acl_dst, aclTensor* acl_exp) { +static void aclnn_pow_tensor_tensor(ggml_backend_cann_context & ctx, aclTensor * acl_dst, aclTensor * acl_exp) { GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp); } @@ -1436,25 +1388,29 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, * @param step Step size for the exponent increment. * @param dtype Data type for slope tensor. */ -static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, - float m, int64_t size, float start, float stop, float step, ggml_type dtype){ - aclDataType acl_type = ggml_cann_type_mapping(dtype); - size_t type_size = ggml_type_size(dtype); +static void aclnn_get_slope_inner(ggml_backend_cann_context & ctx, + void * slope_buffer, + float m, + int64_t size, + float start, + float stop, + float step, + ggml_type dtype) { + aclDataType acl_type = ggml_cann_type_mapping(dtype); + size_t type_size = ggml_type_size(dtype); - int64_t ne[] = {size}; - size_t nb[] = {type_size}; + int64_t ne[] = { size }; + size_t nb[] = { type_size }; ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size); - void* arange_buffer = arange_allocator.get(); + void * arange_buffer = arange_allocator.get(); - aclTensor* arange_tensor = ggml_cann_create_tensor( - arange_buffer, acl_type, type_size, ne, nb, 1); + aclTensor * arange_tensor = ggml_cann_create_tensor(arange_buffer, acl_type, type_size, ne, nb, 1); aclnn_arange(ctx, arange_tensor, start, stop, step, size); - aclTensor* slope_tensor = ggml_cann_create_tensor( - slope_buffer, acl_type, type_size, ne, nb, 1); + aclTensor * slope_tensor = ggml_cann_create_tensor(slope_buffer, acl_type, type_size, ne, nb, 1); - aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); + aclScalar * sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor); ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor); @@ -1486,8 +1442,11 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu * @param dtype Data type for slope tensor. * */ -static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, - void* slope_buffer, float max_bias, ggml_type dtype) { +static void aclnn_get_slope(ggml_backend_cann_context & ctx, + int64_t n_head, + void * slope_buffer, + float max_bias, + ggml_type dtype) { const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); float m0 = powf(2.0f, -(max_bias) / n_head_log2); @@ -1511,9 +1470,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, end = 2 * ((n_head - 1) - n_head_log2) + 1; step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner( - ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), - m1, count, start, end + 1, step, dtype); + aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, + dtype); } } @@ -1538,17 +1496,19 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, * - Write data into dst_ptr using only the shape information of the dst tensor. * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting. */ -static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, - ggml_tensor* dst, void* dst_ptr, float max_bias) { - void* slope_buffer = nullptr; - void* bias_buffer = nullptr; +static void aclnn_add_alibi(ggml_backend_cann_context & ctx, + ggml_tensor * mask, + ggml_tensor * dst, + void * dst_ptr, + float max_bias) { + void * slope_buffer = nullptr; + void * bias_buffer = nullptr; if (max_bias > 0.0f) { - int64_t n_heads = dst->ne[2]; + int64_t n_heads = dst->ne[2]; ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); slope_buffer = slope_allocator.get(); - ggml_cann_pool_alloc bias_allocator( - ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); + ggml_cann_pool_alloc bias_allocator(ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); bias_buffer = bias_allocator.get(); aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32); } @@ -1559,16 +1519,12 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, // broadcast the mask across rows int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 }; - size_t mask_nb[] = { - mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2], - mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] - }; + size_t mask_nb[] = { mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2], + mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] }; int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 }; - size_t dst_nb[] = { - dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2], - dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] - }; + size_t dst_nb[] = { dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2], + dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] }; // slope is a 1 dim tensor, slope.ne2 == dst.ne2 int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 }; @@ -1578,17 +1534,13 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1]; } - aclTensor* acl_slope = ggml_cann_create_tensor( - slope_buffer, ACL_FLOAT, sizeof(float), - slope_ne, slope_nb, GGML_MAX_DIMS + 2); - aclTensor* acl_mask = ggml_cann_create_tensor( - mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); + aclTensor * acl_slope = + ggml_cann_create_tensor(slope_buffer, ACL_FLOAT, sizeof(float), slope_ne, slope_nb, GGML_MAX_DIMS + 2); + aclTensor * acl_mask = ggml_cann_create_tensor(mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); // write data into dst_ptr using only the shape information of the dst tensor. - aclTensor* acl_dst = ggml_cann_create_tensor( - dst_ptr, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst_ne, dst_nb, - GGML_MAX_DIMS + 2); + aclTensor * acl_dst = ggml_cann_create_tensor(dst_ptr, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + dst_ne, dst_nb, GGML_MAX_DIMS + 2); if (max_bias > 0.0f) { int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 }; @@ -1597,9 +1549,8 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, for (int i = 1; i < GGML_MAX_DIMS + 2; i++) { bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1]; } - aclTensor* bias_tensor = ggml_cann_create_tensor( - bias_buffer, ACL_FLOAT, sizeof(float), - bias_ne, bias_nb, GGML_MAX_DIMS + 2); + aclTensor * bias_tensor = + ggml_cann_create_tensor(bias_buffer, ACL_FLOAT, sizeof(float), bias_ne, bias_nb, GGML_MAX_DIMS + 2); aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor); aclnn_add(ctx, acl_dst, bias_tensor); @@ -1628,17 +1579,16 @@ void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) { * @param acl_dst The destination tensor where the softmax results will be * stored. */ -static void aclnn_softmax(ggml_backend_cann_context & ctx, - aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) { +static void aclnn_softmax(ggml_backend_cann_context & ctx, aclTensor * acl_src, int64_t dim, aclTensor * acl_dst) { GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); } void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; // mask + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; // mask - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src0 = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); float scale = 1.0f; float max_bias = 0.0f; @@ -1647,12 +1597,11 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // input mul scale - aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); + aclScalar * acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0)); - void* src_tensor_buffer = src_tensor_allocator.get(); - aclTensor* softmax_tensor = ggml_cann_create_tensor( - src_tensor_buffer, ggml_cann_type_mapping(src0->type), - ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS); + void * src_tensor_buffer = src_tensor_allocator.get(); + aclTensor * softmax_tensor = ggml_cann_create_tensor(src_tensor_buffer, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src0->ne, src0->nb, GGML_MAX_DIMS); aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false); @@ -1684,29 +1633,31 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { * @param index The index tensor specifying the indices to select from the source tensor. * @param type The data type of the source and destination tensors. */ -static void aclnn_index_select_4d(ggml_backend_cann_context& ctx, - void* src_buffer,int64_t* src_ne, size_t* src_nb, - void* dst_buffer, int64_t* dst_ne, size_t* dst_nb, - ggml_tensor* index, ggml_type type) { +static void aclnn_index_select_4d(ggml_backend_cann_context & ctx, + void * src_buffer, + int64_t * src_ne, + size_t * src_nb, + void * dst_buffer, + int64_t * dst_ne, + size_t * dst_nb, + ggml_tensor * index, + ggml_type type) { for (int64_t i = 0; i < src_ne[3]; i++) { for (int64_t j = 0; j < src_ne[2]; j++) { // src - aclTensor* acl_src_tensor = ggml_cann_create_tensor( - (char*)src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), - src_ne, src_nb, 2); + aclTensor * acl_src_tensor = + ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], + ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); // index - aclTensor* acl_index = ggml_cann_create_tensor( - (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), - index->ne, index->nb, 1); + aclTensor * acl_index = ggml_cann_create_tensor( + (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], + ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); // out - aclTensor* acl_out = ggml_cann_create_tensor( - (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), - dst_ne, dst_nb, 2); + aclTensor * acl_out = + ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], + ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor, 0, acl_index, acl_out); ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out); } @@ -1733,162 +1684,154 @@ static void aclnn_index_select_4d(ggml_backend_cann_context& ctx, * @param index The index tensor specifying target positions in the destination tensor. * @param type The data type of the source and destination tensors. */ -static void aclnn_index_copy_4d(ggml_backend_cann_context& ctx, - void* src_buffer,int64_t* src_ne, size_t* src_nb, - void* dst_buffer, int64_t* dst_ne, size_t* dst_nb, - ggml_tensor* index, ggml_type type) { +static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx, + void * src_buffer, + int64_t * src_ne, + size_t * src_nb, + void * dst_buffer, + int64_t * dst_ne, + size_t * dst_nb, + ggml_tensor * index, + ggml_type type) { for (int64_t i = 0; i < src_ne[3]; i++) { for (int64_t j = 0; j < src_ne[2]; j++) { // src - aclTensor* acl_src_tensor = ggml_cann_create_tensor( - (char*)src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), - src_ne, src_nb, 2); + aclTensor * acl_src_tensor = + ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], + ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); // index - aclTensor* acl_index = ggml_cann_create_tensor( - (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), - index->ne, index->nb, 1); + aclTensor * acl_index = ggml_cann_create_tensor( + (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], + ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); // out - aclTensor* acl_out = ggml_cann_create_tensor( - (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), - dst_ne, dst_nb, 2); + aclTensor * acl_out = + ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], + ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out, 0, acl_index, acl_src_tensor); ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out); } } } -void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; // src - ggml_tensor* src1 = dst->src[1]; // index +void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // src + ggml_tensor * src1 = dst->src[1]; // index GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); switch (src0->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: - if(src0->type == dst->type) { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); + if (src0->type == dst->type) { + aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, + dst->type); } else { - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); - void* src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; + aclTensor * acl_src0 = ggml_cann_create_tensor(src0); + ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); + void * src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; src_trans_nb[0] = dst->nb[0]; for (int i = 1; i < GGML_MAX_DIMS; i++) { src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } - aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), - src0->ne, src_trans_nb, GGML_MAX_DIMS); + aclTensor * src_trans_tensor = + ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); + aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, + dst->type); ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); } break; - case GGML_TYPE_Q8_0: { - // add 1 dim for bcast mul. - size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], - dequant_nb[GGML_MAX_DIMS + 1]; - int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], - *dequant_ne; - int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] - weight_ne[0] = QK8_0; - weight_ne[1] = src0->ne[0] / QK8_0; - weight_nb[0] = sizeof(int8_t); - weight_nb[1] = weight_nb[0] * weight_ne[0]; - for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { - weight_ne[i] = src0->ne[i - 1]; - weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; - } - // [3,4,5,64] -> [3,4,5,2,1] - scale_ne[0] = 1; - scale_ne[1] = src0->ne[0] / QK8_0; - scale_nb[0] = sizeof(uint16_t); - scale_nb[1] = scale_nb[0] * scale_ne[0]; - for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { - scale_ne[i] = src0->ne[i - 1]; - scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; - } - // [3,4,5,64] -> [3,4,5,2,32] - dequant_ne = weight_ne; - dequant_nb[0] = ggml_type_size(dst->type); - for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { - dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; - } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); - ggml_cann_pool_alloc dequant_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type)); - aclTensor* acl_weight_tensor = ggml_cann_create_tensor( - src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, - GGML_MAX_DIMS + 1); - aclTensor* acl_scale_tensor = ggml_cann_create_tensor( - src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, - GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); - aclTensor* dequant_tensor = ggml_cann_create_tensor( - dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), - dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); - dequant_nb[0] = ggml_type_size(dst->type); - dequant_ne = src0->ne; - for (int i = 1; i < GGML_MAX_DIMS; i++) { - dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; - } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), - dequant_ne, dequant_nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); + case GGML_TYPE_Q8_0: + { + // add 1 dim for bcast mul. + size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1]; + int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; + int64_t scale_offset = 0; + // [3,4,5,64] -> [3,4,5,2,32] + weight_ne[0] = QK8_0; + weight_ne[1] = src0->ne[0] / QK8_0; + weight_nb[0] = sizeof(int8_t); + weight_nb[1] = weight_nb[0] * weight_ne[0]; + for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { + weight_ne[i] = src0->ne[i - 1]; + weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; + } + // [3,4,5,64] -> [3,4,5,2,1] + scale_ne[0] = 1; + scale_ne[1] = src0->ne[0] / QK8_0; + scale_nb[0] = sizeof(uint16_t); + scale_nb[1] = scale_nb[0] * scale_ne[0]; + for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { + scale_ne[i] = src0->ne[i - 1]; + scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; + } + // [3,4,5,64] -> [3,4,5,2,32] + dequant_ne = weight_ne; + dequant_nb[0] = ggml_type_size(dst->type); + for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { + dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; + } + scale_offset = ggml_nelements(src0) * sizeof(int8_t); + ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + aclTensor * acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, + weight_nb, GGML_MAX_DIMS + 1); + aclTensor * acl_scale_tensor = + ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, + GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); + aclTensor * dequant_tensor = + ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); + aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); + dequant_nb[0] = ggml_type_size(dst->type); + dequant_ne = src0->ne; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; + } + aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, + dst->nb, src1, dst->type); - ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); - break; - } + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); + break; + } default: GGML_ABORT("Unsupported tensor type for GGML_OP_GET_ROWS"); break; } } -void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; // src - ggml_tensor* src1 = dst->src[1]; // index +void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // src + ggml_tensor * src1 = dst->src[1]; // index switch (dst->type) { - case GGML_TYPE_F32: { - aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); - break; - } - case GGML_TYPE_F16: { - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); - void* src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(uint16_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + case GGML_TYPE_F32: + { + aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type); + break; + } + case GGML_TYPE_F16: + { + aclTensor * acl_src0 = ggml_cann_create_tensor(src0); + ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); + void * src_trans_buffer = src_buffer_allocator.get(); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(uint16_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + aclTensor * src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); + aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, + dst->type); + ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); + break; } - aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), - src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); - aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, - dst->data, dst->ne, dst->nb, - src1, dst->type); - ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); - break; - } default: GGML_ABORT("Unsupported tensor type for GGML_OP_SET_ROWS"); break; @@ -1910,12 +1853,13 @@ void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param repeats The number of times each element will be repeated. * @param output_size The size of the output tensor. */ -static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, - aclTensor* acl_src, aclTensor* acl_dst, - int64_t dim, int64_t repeats, - int64_t output_size) { - GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim, - output_size, acl_dst); +static void aclnn_repeat_interleave(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_dst, + int64_t dim, + int64_t repeats, + int64_t output_size) { + GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim, output_size, acl_dst); } /** @@ -1930,10 +1874,9 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, * @param dst The destination tensor where the result of the matrix * multiplication will be stored. */ -static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, - ggml_tensor* dst) { - ggml_tensor* weight = dst->src[0]; // weight - ggml_tensor* input = dst->src[1]; // input +static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * weight = dst->src[0]; // weight + ggml_tensor * input = dst->src[1]; // input // when weight ne2 or ne3 is 1, aclnnMatmulGetWorkspaceSize will auto // broadcast, when weight ne2 or ne3 is not 1, weight need repeat. @@ -1948,27 +1891,21 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, } } - aclTensor* acl_input_tensor = - ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims); - int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0], - bcast_weight_ne[2], bcast_weight_ne[3], - bcast_weight_ne[4], bcast_weight_ne[5]}; - size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0], - bcast_weight_nb[2], bcast_weight_nb[3], - bcast_weight_nb[4], bcast_weight_nb[5]}; - aclTensor* acl_weight_tensor; + aclTensor * acl_input_tensor = ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims); + int64_t transpose_ne[] = { bcast_weight_ne[1], bcast_weight_ne[0], bcast_weight_ne[2], + bcast_weight_ne[3], bcast_weight_ne[4], bcast_weight_ne[5] }; + size_t transpose_nb[] = { bcast_weight_nb[1], bcast_weight_nb[0], bcast_weight_nb[2], + bcast_weight_nb[3], bcast_weight_nb[4], bcast_weight_nb[5] }; + aclTensor * acl_weight_tensor; // Only check env once. static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { - acl_weight_tensor = - ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); + acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { - acl_weight_tensor = - ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); + acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); } - aclTensor* acl_dst = - ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims); switch (n_dims) { case 2: @@ -2000,11 +1937,9 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, * @param dst The destination tensor where the result of the matrix * multiplication will be stored. */ -static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, - ggml_tensor* dst, - const enum ggml_type type) { - ggml_tensor* src0 = dst->src[0]; // weight - ggml_tensor* src1 = dst->src[1]; // input +static void ggml_cann_mul_mat_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst, const enum ggml_type type) { + ggml_tensor * src0 = dst->src[0]; // weight + ggml_tensor * src1 = dst->src[1]; // input // The shape of the weight is NCHW. // Matrix multiplication uses HW dims. @@ -2018,56 +1953,52 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, } else { GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT"); } - float weight_nb[] = {src0->ne[0] * weight_elem_size, weight_elem_size}; + float weight_nb[] = { src0->ne[0] * weight_elem_size, weight_elem_size }; size_t weight_stride = src0->ne[1] * src0->ne[0] * weight_elem_size; - size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3]; + size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3]; // scale stored at the end of weight. Also need transpose. size_t scale_elem_size = sizeof(uint16_t); - size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size, - scale_elem_size}; - size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; - char* scale_offset = (char*)src0->data + weight_size; + size_t scale_nb[] = { src0->ne[0] / QK8_0 * scale_elem_size, scale_elem_size }; + size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + char * scale_offset = (char *) src0->data + weight_size; // input - size_t input_elem_size = sizeof(uint16_t); - int64_t input_ne[] = {src1->ne[0], src1->ne[1]}; - size_t input_nb[] = {input_elem_size, input_ne[0] * input_elem_size}; - size_t input_stride = input_ne[0] * input_ne[1] * input_elem_size; + size_t input_elem_size = sizeof(uint16_t); + int64_t input_ne[] = { src1->ne[0], src1->ne[1] }; + size_t input_nb[] = { input_elem_size, input_ne[0] * input_elem_size }; + size_t input_stride = input_ne[0] * input_ne[1] * input_elem_size; ggml_cann_pool_alloc input_alloctor(ctx.pool()); - void* input_buffer = src1->data; + void * input_buffer = src1->data; // case in if (src1->type != GGML_TYPE_F16) { - aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1); - input_buffer = - input_alloctor.alloc(ggml_nelements(src1) * input_elem_size); + aclTensor * acl_src1_tensor = ggml_cann_create_tensor(src1); + input_buffer = input_alloctor.alloc(ggml_nelements(src1) * input_elem_size); - int64_t* input_cast_ne = src1->ne; - size_t input_cast_nb[GGML_MAX_DIMS]; + int64_t * input_cast_ne = src1->ne; + size_t input_cast_nb[GGML_MAX_DIMS]; input_cast_nb[0] = sizeof(uint16_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1]; } - aclTensor* acl_input_tensor = ggml_cann_create_tensor( - input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne, - input_cast_nb, GGML_MAX_DIMS); + aclTensor * acl_input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, input_elem_size, + input_cast_ne, input_cast_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16); ggml_cann_release_resources(ctx, acl_input_tensor, acl_src1_tensor); } // output - size_t output_elem_size = sizeof(uint16_t); - size_t output_nb[] = {output_elem_size, dst->ne[0] * output_elem_size}; + size_t output_elem_size = sizeof(uint16_t); + size_t output_nb[] = { output_elem_size, dst->ne[0] * output_elem_size }; ggml_cann_pool_alloc output_allocator(ctx.pool()); - void* output_buffer = - output_allocator.alloc(ggml_nelements(dst) * output_elem_size); - size_t output_stride = dst->ne[0] * dst->ne[1] * output_elem_size; + void * output_buffer = output_allocator.alloc(ggml_nelements(dst) * output_elem_size); + size_t output_stride = dst->ne[0] * dst->ne[1] * output_elem_size; // aclnn - int64_t max_elem_size = 65535; - int64_t split_size = (src0->ne[1] / max_elem_size) + 1; + int64_t max_elem_size = 65535; + int64_t split_size = (src0->ne[1] / max_elem_size) + 1; ggml_cann_pool_alloc workspace_allocator(ctx.pool()); for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) { for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) { @@ -2077,71 +2008,57 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, int64_t batch1 = (n1 * src1->ne[2]) + c1; int64_t batch0 = (n0 * src0->ne[2]) + c0; - aclTensor* acl_input_tensor = ggml_cann_create_tensor( - (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16, - input_elem_size, input_ne, input_nb, 2); + aclTensor * acl_input_tensor = ggml_cann_create_tensor((char *) input_buffer + batch1 * input_stride, + ACL_FLOAT16, input_elem_size, input_ne, input_nb, 2); // first split int64_t weight_ne_offset = 0; - int64_t weight_ne[2] = { - max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size, - src0->ne[0]}; - int64_t scale_ne_offset = 0; - int64_t scale_ne[2] = {weight_ne[0], weight_ne[1] / QK8_0}; + int64_t weight_ne[2] = { max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size, src0->ne[0] }; + int64_t scale_ne_offset = 0; + int64_t scale_ne[2] = { weight_ne[0], weight_ne[1] / QK8_0 }; int64_t output_ne_offset = 0; - int64_t output_ne[2] = {weight_ne[0], dst->ne[1]}; + int64_t output_ne[2] = { weight_ne[0], dst->ne[1] }; - aclTensor* acl_weight_tensor = ggml_cann_create_tensor( - (char*)src0->data + batch0 * weight_stride, - ggml_cann_type_mapping(type), weight_elem_size, weight_ne, - weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset); - aclTensor* acl_scale_tensor = ggml_cann_create_tensor( - scale_offset + batch0 * scale_stride, ACL_FLOAT16, - scale_elem_size, scale_ne, scale_nb, 2, ACL_FORMAT_ND, - scale_ne_offset); - aclTensor* acl_output_tensor = ggml_cann_create_tensor( - (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, - output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, - output_ne_offset); + aclTensor * acl_weight_tensor = + ggml_cann_create_tensor((char *) src0->data + batch0 * weight_stride, ggml_cann_type_mapping(type), + weight_elem_size, weight_ne, weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset); + aclTensor * acl_scale_tensor = + ggml_cann_create_tensor(scale_offset + batch0 * scale_stride, ACL_FLOAT16, scale_elem_size, scale_ne, + scale_nb, 2, ACL_FORMAT_ND, scale_ne_offset); + aclTensor * acl_output_tensor = + ggml_cann_create_tensor((char *) output_buffer + batch1 * output_stride, ACL_FLOAT16, output_elem_size, + output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset); int64_t antiquantGroupSize = 0; if (src0->ne[0] > QK8_0) { antiquantGroupSize = QK8_0; } - GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, - acl_weight_tensor, acl_scale_tensor, nullptr, - nullptr, nullptr, nullptr, antiquantGroupSize, - acl_output_tensor); + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, acl_weight_tensor, + acl_scale_tensor, nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, + acl_output_tensor); ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor); // other splits for (int64_t split = 1; split < split_size; split++) { - weight_ne_offset += - weight_elem_size * weight_ne[0] * weight_ne[1]; - weight_ne[0] = max_elem_size * (split + 1) > src0->ne[1] - ? src0->ne[1] - (max_elem_size * split) - : max_elem_size; + weight_ne_offset += weight_elem_size * weight_ne[0] * weight_ne[1]; + weight_ne[0] = + max_elem_size * (split + 1) > src0->ne[1] ? src0->ne[1] - (max_elem_size * split) : max_elem_size; scale_ne_offset += scale_elem_size * scale_ne[0] * scale_ne[1]; scale_ne[0] = weight_ne[0]; - output_ne_offset += - output_elem_size * output_ne[0] * output_ne[1]; + output_ne_offset += output_elem_size * output_ne[0] * output_ne[1]; output_ne[0] = weight_ne[0]; - acl_weight_tensor = ggml_cann_create_tensor( - (char*)src0->data + batch0 * weight_stride, - ggml_cann_type_mapping(type), weight_elem_size, weight_ne, - weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset); - acl_scale_tensor = ggml_cann_create_tensor( - scale_offset + batch0 * scale_stride, ACL_FLOAT16, - scale_elem_size, scale_ne, scale_nb, 2, ACL_FORMAT_ND, - scale_ne_offset); - acl_output_tensor = ggml_cann_create_tensor( - (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, - output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, - output_ne_offset); - GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, - acl_weight_tensor, acl_scale_tensor, nullptr, - nullptr, nullptr, nullptr, antiquantGroupSize, - acl_output_tensor); + acl_weight_tensor = + ggml_cann_create_tensor((char *) src0->data + batch0 * weight_stride, ggml_cann_type_mapping(type), + weight_elem_size, weight_ne, weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset); + acl_scale_tensor = + ggml_cann_create_tensor(scale_offset + batch0 * scale_stride, ACL_FLOAT16, scale_elem_size, + scale_ne, scale_nb, 2, ACL_FORMAT_ND, scale_ne_offset); + acl_output_tensor = + ggml_cann_create_tensor((char *) output_buffer + batch1 * output_stride, ACL_FLOAT16, + output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset); + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, acl_weight_tensor, + acl_scale_tensor, nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, + acl_output_tensor); ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor); } @@ -2151,24 +2068,23 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, // cast out if (dst->type != GGML_TYPE_F16) { - int64_t* output_cast_ne = dst->ne; - size_t output_cast_nb[GGML_MAX_DIMS]; + int64_t * output_cast_ne = dst->ne; + size_t output_cast_nb[GGML_MAX_DIMS]; output_cast_nb[0] = sizeof(uint16_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1]; } - aclTensor* acl_output_tensor = ggml_cann_create_tensor( - output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne, - output_cast_nb, GGML_MAX_DIMS); - aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + aclTensor * acl_output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size, + output_cast_ne, output_cast_nb, GGML_MAX_DIMS); + aclTensor * acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); ggml_cann_release_resources(ctx, acl_output_tensor, acl_dst_tensor); } } -void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { const enum ggml_type type = dst->src[0]->type; switch (type) { case GGML_TYPE_F32: @@ -2201,10 +2117,13 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param dims An array specifying the dimensions along which elements are * shifted. */ -static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, int64_t* shifts, int64_t* dims) { - aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1); - aclIntArray* acl_dims = aclCreateIntArray(dims, 1); +static void aclnn_roll(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_dst, + int64_t * shifts, + int64_t * dims) { + aclIntArray * acl_shifts = aclCreateIntArray(shifts, 1); + aclIntArray * acl_dims = aclCreateIntArray(dims, 1); GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts, acl_dims, acl_dst); ggml_cann_release_resources(ctx, acl_shifts, acl_dims); } @@ -2222,12 +2141,14 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param index_num The number of positions specified in the index array. * @param value The scalar value used to fill the specified positions. */ -static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, - aclTensor* acl_src, int64_t dim, - int64_t* index, int64_t index_num, - float value) { - aclIntArray* acl_index = aclCreateIntArray(index, index_num); - aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); +static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + int64_t dim, + int64_t * index, + int64_t index_num, + float value) { + aclIntArray * acl_index = aclCreateIntArray(index, index_num); + aclScalar * acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value); ggml_cann_release_resources(ctx, acl_index, acl_value); } @@ -2262,85 +2183,82 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * @param is_neox Whether to use Neox-style repeat strategy * (dim expansion vs repeat_interleave). */ -static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, - float* corr_dims, float ext_factor, - float theta_scale, float freq_scale, - float attn_factor, bool is_neox) { - ggml_tensor* src0 = dst->src[0]; // input - ggml_tensor* src1 = dst->src[1]; // position - ggml_tensor* src2 = dst->src[2]; // freq_factors +static void aclnn_cache_init(ggml_backend_cann_context & ctx, + ggml_tensor * dst, + float * corr_dims, + float ext_factor, + float theta_scale, + float freq_scale, + float attn_factor, + bool is_neox) { + ggml_tensor * src0 = dst->src[0]; // input + ggml_tensor * src1 = dst->src[1]; // position + ggml_tensor * src2 = dst->src[2]; // freq_factors - if(src2 == nullptr && ctx.rope_cache.cached - && ctx.rope_cache.ext_factor == ext_factor - && ctx.rope_cache.theta_scale == theta_scale - && ctx.rope_cache.freq_scale == freq_scale - && ctx.rope_cache.attn_factor == attn_factor - && ctx.rope_cache.is_neox == is_neox) { + if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor && + ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale && + ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) { // use cache. return; } int64_t theta_scale_length = src0->ne[0] / 2; - int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; - size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float), - theta_scale_length * sizeof(float)}; + int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 }; + size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) }; GGML_ASSERT(src1->type == GGML_TYPE_I32); int64_t position_length = src1->ne[0]; - int64_t position_ne[] = {1, 1, position_length, 1}; - size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), - sizeof(int32_t) * position_length}; + int64_t position_ne[] = { 1, 1, position_length, 1 }; + size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length }; - int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; - size_t theta_nb[GGML_MAX_DIMS]; + int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 }; + size_t theta_nb[GGML_MAX_DIMS]; theta_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; } // theta_scale arange, [0,1,...,ne00/2 - 1] - aclTensor* acl_theta_scale_tensor = nullptr; + aclTensor * acl_theta_scale_tensor = nullptr; // cache theta scale if (ctx.rope_cache.theta_scale_length != theta_scale_length || // theta_scale and freq_scale should not change during the current token inference process, // so we can directly use == here instead of comparing the absolute difference. - ctx.rope_cache.theta_scale != theta_scale || - ctx.rope_cache.freq_scale != freq_scale) { - + ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { ctx.rope_cache.theta_scale_length = theta_scale_length; if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); } - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), + ACL_MEM_MALLOC_HUGE_FIRST)); - acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - float start = 0; - float step = 1; - float stop = theta_scale_length; + float start = 0; + float step = 1; + float stop = theta_scale_length; float n_elements = theta_scale_length; aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); - aclTensor* acl_yarn_ramp_tensor = nullptr; + aclTensor * acl_yarn_ramp_tensor = nullptr; if (ext_factor != 0) { // -rope_yarn_ramp // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); // return MIN(1, MAX(0, y)) - 1; yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void* yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - float zero_value = 0, one_value = 1; - float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); - aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); - aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT); - aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT); - aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT); - aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT); + void * yarn_ramp_buffer = yarn_ramp_allocator.get(); + acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, + theta_scale_nb, GGML_MAX_DIMS); + float zero_value = 0, one_value = 1; + float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); + aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); + aclScalar * zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT); + aclScalar * one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT); + aclScalar * denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT); + aclScalar * ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe); @@ -2357,9 +2275,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse // cache freq_scale + (freq_scale - 1) * ramp_mix - float freq_scale_1 = freq_scale - 1; - aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT); - aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT); + float freq_scale_1 = freq_scale - 1; + aclScalar * freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT); + aclScalar * freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one); @@ -2367,9 +2285,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } // power - aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, - acl_theta_scale_tensor); + aclScalar * acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, acl_theta_scale_tensor); if (ext_factor != 0) { aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor); @@ -2380,22 +2297,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale); } else { // use cache - acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); // freq_factors if (src2) { freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); - void* freq_fac_res_ptr = freq_fac_res_allocator.get(); - aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( - src2->data, ggml_cann_type_mapping(src2->type), - ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor( - freq_fac_res_ptr, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + void * freq_fac_res_ptr = freq_fac_res_allocator.get(); + aclTensor * acl_freq_factors_tensor = + ggml_cann_create_tensor(src2->data, ggml_cann_type_mapping(src2->type), ggml_type_size(src2->type), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + aclTensor * acl_freq_fac_res_tensor = ggml_cann_create_tensor(freq_fac_res_ptr, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor); std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor); @@ -2411,42 +2326,37 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache)); } int64_t repeat_theta_length = theta_scale_length * position_length * 2; - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); } // position - aclTensor* acl_position_tensor = ggml_cann_create_tensor( - src1->data, ggml_cann_type_mapping(src1->type), - ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); + aclTensor * acl_position_tensor = + ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, + position_nb, GGML_MAX_DIMS); // power * position - int64_t theta_length = theta_scale_length * position_length; - ggml_cann_pool_alloc theta_allocator(ctx.pool(), - theta_length * sizeof(float)); - void* theta_buffer = theta_allocator.get(); + int64_t theta_length = theta_scale_length * position_length; + ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float)); + void * theta_buffer = theta_allocator.get(); - aclTensor* acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), - theta_ne, theta_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, - acl_theta_tensor); + aclTensor * acl_theta_tensor = + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, acl_theta_tensor); // sin/cos - ggml_cann_pool_alloc sin_allocator(ctx.pool(), - theta_length * sizeof(float)); - void* sin_buffer = sin_allocator.get(); - aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float)); + void * sin_buffer = sin_allocator.get(); + aclTensor * acl_sin_tensor = + ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); - ggml_cann_pool_alloc cos_allocator(ctx.pool(), - theta_length * sizeof(float)); - void* cos_buffer = cos_allocator.get(); - aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); + ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float)); + void * cos_buffer = cos_allocator.get(); + aclTensor * acl_cos_tensor = + ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); if (ext_factor != 0) { @@ -2459,81 +2369,79 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; - size_t sin_reshape_nb[GGML_MAX_DIMS]; + int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 }; + size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } - aclTensor* acl_sin_repeat_tensor = - ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), - sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - aclTensor* acl_cos_repeat_tensor = - ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), - sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor * acl_sin_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor * acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); // repeat if (is_neox) { - int64_t repeatsArray[] = {1, 1, 1, 2}; + int64_t repeatsArray[] = { 1, 1, 1, 2 }; aclnn_repeat(ctx, acl_sin_tensor, acl_sin_repeat_tensor, repeatsArray); aclnn_repeat(ctx, acl_cos_tensor, acl_cos_repeat_tensor, repeatsArray); } else { int64_t num_repeats = 2; - int64_t dim = 3; + int64_t dim = 3; int64_t output_size = theta_scale_length * num_repeats; - aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim, - num_repeats, output_size); - aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim, - num_repeats, output_size); + aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim, num_repeats, output_size); + aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim, num_repeats, output_size); } // Other layers use cache except first layer. - ctx.rope_cache.cached = true; - ctx.rope_cache.ext_factor = ext_factor; + ctx.rope_cache.cached = true; + ctx.rope_cache.ext_factor = ext_factor; ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; + ctx.rope_cache.freq_scale = freq_scale; ctx.rope_cache.attn_factor = attn_factor; - ctx.rope_cache.is_neox = is_neox; + ctx.rope_cache.is_neox = is_neox; - ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, - acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor, - acl_cos_repeat_tensor); + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, acl_theta_tensor, acl_sin_tensor, + acl_sin_repeat_tensor, acl_cos_tensor, acl_cos_repeat_tensor); } #ifdef __cplusplus extern "C" { #endif -aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize( - const aclTensor* x, const aclTensor* cos, const aclTensor* sin, - int64_t mode, const aclTensor* yOut, uint64_t* workspaceSize, - aclOpExecutor** executor); -aclnnStatus aclnnRotaryPositionEmbedding(void* workspace, - uint64_t workspaceSize, - aclOpExecutor* executor, - aclrtStream stream); +aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize(const aclTensor * x, + const aclTensor * cos, + const aclTensor * sin, + int64_t mode, + const aclTensor * yOut, + uint64_t * workspaceSize, + aclOpExecutor ** executor); +aclnnStatus aclnnRotaryPositionEmbedding(void * workspace, + uint64_t workspaceSize, + aclOpExecutor * executor, + aclrtStream stream); #ifdef __cplusplus } #endif -void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; // input +void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // input // param - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; // const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t*)dst->op_params)[1]; - const int mode = ((int32_t*)dst->op_params)[2]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t*)dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; GGML_TENSOR_UNARY_OP_LOCALS - memcpy(&freq_base, (int32_t*)dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t*)dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t*)dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t*)dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); // TODO: n_dims <= ne0 GGML_ASSERT(n_dims == ne0); @@ -2542,123 +2450,111 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const float theta_scale = powf(freq_base, -2.0f / n_dims); float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, - beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, corr_dims, ext_factor, - theta_scale, freq_scale, attn_factor, is_neox); + aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); - int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; - size_t sin_reshape_nb[GGML_MAX_DIMS]; + int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; + size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } - aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), - sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), - sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor * acl_sin_reshape_tensor = ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor * acl_cos_reshape_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); #ifdef ASCEND_310P // Special ROPE operation for 310P // roll input - void* input_roll_buffer; - aclTensor* acl_minus_one_tensor; - void* minus_one_scale_buffer = nullptr; + void * input_roll_buffer; + aclTensor * acl_minus_one_tensor; + void * minus_one_scale_buffer = nullptr; ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0)); - ggml_cann_pool_alloc minus_one_scale_allocator( - ctx.pool(), sizeof(float) * src0->ne[0]); + ggml_cann_pool_alloc minus_one_scale_allocator(ctx.pool(), sizeof(float) * src0->ne[0]); if (!is_neox) { // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...] - input_roll_buffer = roll_allocator.get(); - int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2), - src0->ne[2], src0->ne[3]}; - size_t input_roll_nb[GGML_MAX_DIMS]; + input_roll_buffer = roll_allocator.get(); + int64_t input_roll_ne[4] = { 2, src0->ne[1] * (src0->ne[0] / 2), src0->ne[2], src0->ne[3] }; + size_t input_roll_nb[GGML_MAX_DIMS]; input_roll_nb[0] = ggml_type_size(src0->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1]; } - aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor( - input_roll_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), input_roll_ne, input_roll_nb, - GGML_MAX_DIMS); - aclTensor* acl_input_tensor = ggml_cann_create_tensor( - src0->data, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), input_roll_ne, input_roll_nb, - GGML_MAX_DIMS); + aclTensor * acl_input_roll_tensor = + ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + input_roll_ne, input_roll_nb, GGML_MAX_DIMS); + aclTensor * acl_input_tensor = + ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + input_roll_ne, input_roll_nb, GGML_MAX_DIMS); - int64_t shifts[] = {1}; - int64_t dims[] = {3}; + int64_t shifts[] = { 1 }; + int64_t dims[] = { 3 }; aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); // init [-1, 1, -1, 1, ...] minus_one_scale_buffer = minus_one_scale_allocator.get(); - int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; - size_t minus_one_nb[GGML_MAX_DIMS]; + int64_t minus_one_ne[4] = { src0->ne[0], 1, 1, 1 }; + size_t minus_one_nb[GGML_MAX_DIMS]; minus_one_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; } - acl_minus_one_tensor = aclnn_values( - ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], - minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1); - int64_t dim = 3; - int64_t* index = new int64_t[src0->ne[0]]; + acl_minus_one_tensor = aclnn_values(ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], minus_one_ne, + GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1); + int64_t dim = 3; + int64_t * index = new int64_t[src0->ne[0]]; for (int i = 0; i < src0->ne[0]; i++) { index[i] = i / 2 * 2; } int64_t index_num = src0->ne[0]; - float value = -1; - aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, - index_num, value); + float value = -1; + aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, index_num, value); } else { // roll input: [q0,q1,q2,...] -> // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1] input_roll_buffer = roll_allocator.get(); - aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor( - input_roll_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); - aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0); + aclTensor * acl_input_roll_tensor = + ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + src0->ne, src0->nb, GGML_MAX_DIMS); + aclTensor * acl_input_tensor = ggml_cann_create_tensor(src0); - int64_t shifts[] = {src0->ne[0] / 2}; - int64_t dims[] = {3}; + int64_t shifts[] = { src0->ne[0] / 2 }; + int64_t dims[] = { 3 }; aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); // init [-1, -1, -1, 1, 1,1,...] - minus_one_scale_buffer = minus_one_scale_allocator.get(); - int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; - size_t minus_one_nb[GGML_MAX_DIMS]; + minus_one_scale_buffer = minus_one_scale_allocator.get(); + int64_t minus_one_ne[4] = { src0->ne[0], 1, 1, 1 }; + size_t minus_one_nb[GGML_MAX_DIMS]; minus_one_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; } - acl_minus_one_tensor = aclnn_values( - ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], - minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1); + acl_minus_one_tensor = aclnn_values(ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], minus_one_ne, + GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1); // -1 * first half - int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1}; - size_t first_half_nb[GGML_MAX_DIMS]; + int64_t first_half_ne[4] = { src0->ne[0] / 2, 1, 1, 1 }; + size_t first_half_nb[GGML_MAX_DIMS]; first_half_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1]; } - aclTensor* acl_first_half_tensor = ggml_cann_create_tensor( - minus_one_scale_buffer, ACL_FLOAT, sizeof(float), first_half_ne, - first_half_nb, GGML_MAX_DIMS); - bool inplace = true; - float scale = -1; + aclTensor * acl_first_half_tensor = ggml_cann_create_tensor(minus_one_scale_buffer, ACL_FLOAT, sizeof(float), + first_half_ne, first_half_nb, GGML_MAX_DIMS); + bool inplace = true; + float scale = -1; aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace); ggml_cann_release_resources(ctx, acl_first_half_tensor); } @@ -2667,30 +2563,27 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { GGML_ASSERT(n_dims == src0->ne[0]); // input * scale - ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(), - ggml_nbytes(src0)); - void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get(); - size_t input_nb[GGML_MAX_DIMS]; + ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(), ggml_nbytes(src0)); + void * input_roll_mul_scale_buffer = roll_mul_scale_allocator.get(); + size_t input_nb[GGML_MAX_DIMS]; input_nb[0] = ggml_type_size(src0->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { input_nb[i] = input_nb[i - 1] * src0->ne[i - 1]; } - aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor( - input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); - aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor( - input_roll_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); + aclTensor * acl_input_roll_mul_scale_tensor = + ggml_cann_create_tensor(input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); + aclTensor * acl_input_roll_reshape_tensor = + ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + src0->ne, input_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, - acl_input_roll_mul_scale_tensor); + aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, acl_input_roll_mul_scale_tensor); // output - void* output_fp32_buffer; + void * output_fp32_buffer; if (src0->type == GGML_TYPE_F32) { aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor); - aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, - acl_sin_reshape_tensor); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor); aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); // TODO: ne0 != n_dims in mode2 } else if (src0->type == GGML_TYPE_F16) { @@ -2699,36 +2592,27 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { for (int i = 1; i < GGML_MAX_DIMS; i++) { input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1]; } - ggml_cann_pool_alloc fp32_allocator1( - ctx.pool(), ggml_nelements(dst) * sizeof(float)); - void* input_fp32_buffer1 = fp32_allocator1.get(); - aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor( - input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne, - input_fp32_nb, GGML_MAX_DIMS); - ggml_cann_pool_alloc fp32_allocator2( - ctx.pool(), ggml_nelements(dst) * sizeof(float)); - void* input_fp32_buffer2 = fp32_allocator2.get(); - aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor( - input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne, - input_fp32_nb, GGML_MAX_DIMS); + ggml_cann_pool_alloc fp32_allocator1(ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void * input_fp32_buffer1 = fp32_allocator1.get(); + aclTensor * input_fp32_tensor1 = ggml_cann_create_tensor(input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne, + input_fp32_nb, GGML_MAX_DIMS); + ggml_cann_pool_alloc fp32_allocator2(ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void * input_fp32_buffer2 = fp32_allocator2.get(); + aclTensor * input_fp32_tensor2 = ggml_cann_create_tensor(input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne, + input_fp32_nb, GGML_MAX_DIMS); - ggml_cann_pool_alloc fp32_allocator( - ctx.pool(), ggml_nelements(dst) * sizeof(float)); - output_fp32_buffer = fp32_allocator.get(); - aclTensor* output_fp32_tensor = ggml_cann_create_tensor( - output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne, - input_fp32_nb, GGML_MAX_DIMS); + ggml_cann_pool_alloc fp32_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float)); + output_fp32_buffer = fp32_allocator.get(); + aclTensor * output_fp32_tensor = ggml_cann_create_tensor(output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne, + input_fp32_nb, GGML_MAX_DIMS); aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); - aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, - input_fp32_tensor2); - aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, - output_fp32_tensor); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, input_fp32_tensor2); + aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, output_fp32_tensor); aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16); - ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2, - output_fp32_tensor, acl_sin_reshape_tensor, - acl_minus_one_tensor, acl_input_roll_mul_scale_tensor, - acl_input_roll_reshape_tensor, acl_src); + ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2, output_fp32_tensor, + acl_sin_reshape_tensor, acl_minus_one_tensor, acl_input_roll_mul_scale_tensor, + acl_input_roll_reshape_tensor, acl_src); } return; #endif @@ -2737,155 +2621,146 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t acl_mode = mode == 0 ? 1 : mode; switch (src0->type) { - case GGML_TYPE_F32: { - GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src, - acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst); - break; - } - case GGML_TYPE_F16: { - ggml_cann_pool_alloc src_trans_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float)); - void* src_trans_buffer = src_trans_allocator.get(); - ggml_cann_pool_alloc dst_trans_allocator( - ctx.pool(), ggml_nelements(dst) * sizeof(float)); - void* dst_trans_buffer = dst_trans_allocator.get(); - - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(float); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + case GGML_TYPE_F32: + { + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_mode, acl_dst); + break; } + case GGML_TYPE_F16: + { + ggml_cann_pool_alloc src_trans_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(float)); + void * src_trans_buffer = src_trans_allocator.get(); + ggml_cann_pool_alloc dst_trans_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void * dst_trans_buffer = dst_trans_allocator.get(); - aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, - GGML_MAX_DIMS); - aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor( - dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, - GGML_MAX_DIMS); + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } - aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); + aclTensor * acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), + src0->ne, src_trans_nb, GGML_MAX_DIMS); + aclTensor * acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), + dst->ne, src_trans_nb, GGML_MAX_DIMS); - GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor, - acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, - acl_dst_trans_tensor); + aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); - aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor); - ggml_cann_release_resources(ctx, acl_src_trans_tensor, - acl_dst_trans_tensor); - break; - } + aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); + + ggml_cann_release_resources(ctx, acl_src_trans_tensor, acl_dst_trans_tensor); + break; + } default: GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); break; } - ggml_cann_release_resources(ctx, acl_cos_reshape_tensor, - acl_sin_reshape_tensor, acl_src, acl_dst); + ggml_cann_release_resources(ctx, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_src, acl_dst); } - - void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); + aclTensor * acl_src = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src, 3, false, acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst); } -void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; // stride - int64_t s0 = ((const int32_t*)(dst->op_params))[0]; + int64_t s0 = ((const int32_t *) (dst->op_params))[0]; - aclTensor* acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); - aclTensor* acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); - aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + aclTensor * acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); + aclTensor * acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); + aclTensor * acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); int64_t strideVal[1]; - strideVal[0] = s0; - aclIntArray *stride = aclCreateIntArray(strideVal, 1); - int64_t paddingVal[] = {0}; - aclIntArray *padding = aclCreateIntArray(paddingVal, 1); - int64_t dilationVal[] = {1}; - aclIntArray *dilation = aclCreateIntArray(dilationVal, 1); - int8_t cubeMathType = 0; + strideVal[0] = s0; + aclIntArray * stride = aclCreateIntArray(strideVal, 1); + int64_t paddingVal[] = { 0 }; + aclIntArray * padding = aclCreateIntArray(paddingVal, 1); + int64_t dilationVal[] = { 1 }; + aclIntArray * dilation = aclCreateIntArray(dilationVal, 1); + int8_t cubeMathType = 0; #ifdef ASCEND_310P cubeMathType = 1; #endif - GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride, - padding, dilation, true, padding, 1, acl_dst, cubeMathType); + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride, padding, dilation, true, padding, + 1, acl_dst, cubeMathType); ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation); } -void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; - aclTensor* acl_input = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_input = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - float alphaValue = 1.0f; - aclScalar* alpha = nullptr; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + float alphaValue = 1.0f; + aclScalar * alpha = nullptr; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha, - acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha, acl_dst); ggml_cann_release_resources(ctx, acl_input, acl_dst, alpha); } -void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - int64_t reduceDimValue[] = {3}; - aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1); - bool keepDim = true; + int64_t reduceDimValue[] = { 3 }; + aclIntArray * reduceDim = aclCreateIntArray(reduceDimValue, 1); + bool keepDim = true; GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst, reduceDim); } -void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ - ggml_tensor * src0 = dst->src[0]; - int32_t *opts = (int32_t *) dst->op_params; - int64_t paddingsArray[2] = {opts[0], opts[1]}; - aclIntArray* paddings = aclCreateIntArray(paddingsArray, 2); +void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + int32_t * opts = (int32_t *) dst->op_params; + int64_t paddingsArray[2] = { opts[0], opts[1] }; + aclIntArray * paddings = aclCreateIntArray(paddingsArray, 2); for (int64_t i = 0; i < src0->ne[3]; i++) { - aclTensor* acl_src = ggml_cann_create_tensor( - (char*)src0->data + i * src0->ne[3], - ggml_cann_type_mapping(src0->type), ggml_element_size(src0), - src0->ne, src0->nb, 3); + aclTensor * acl_src = + ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src0->ne, src0->nb, 3); - aclTensor* acl_dst = ggml_cann_create_tensor( - (char*)dst->data + i * src0->ne[3], - ggml_cann_type_mapping(dst->type), ggml_element_size(dst), - dst->ne, dst->nb, 3); + aclTensor * acl_dst = + ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type), + ggml_element_size(dst), dst->ne, dst->nb, 3); - GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst); - ggml_cann_release_resources(ctx, acl_src, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } ggml_cann_release_resources(ctx, paddings); } -void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; - aclTensor* acl_self = ggml_cann_create_tensor(src0); - aclTensor* acl_other = ggml_cann_create_tensor(src1); + aclTensor * acl_self = ggml_cann_create_tensor(src0); + aclTensor * acl_other = ggml_cann_create_tensor(src1); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self, acl_other); @@ -2894,15 +2769,15 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_self, acl_other); } -void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src0); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); - float alphaValue = 0.0f; - aclScalar* alpha = nullptr; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + float alphaValue = 0.0f; + aclScalar * alpha = nullptr; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src, alpha, acl_dst); @@ -2927,7 +2802,7 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ * @note This function assumes floating-point data types and is designed for * MoE architectures, possibly involving sparse expert routing. */ -static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { //dst [M, K, N, 1] ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1] ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1] @@ -2941,36 +2816,42 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* GGML_ASSERT(batch == ids->ne[1]); ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0)); - void* export_ptr = export_allocator.get(); + void * export_ptr = export_allocator.get(); for (int64_t i = 0; i < batch; i++) { - aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]); - aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3); + aclTensor * select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]); + aclTensor * export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3); - int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]}; - size_t select_export_nb[3]; + int64_t select_export_ne[] = { src0->ne[0], src0->ne[1], ids->ne[0] }; + size_t select_export_nb[3]; select_export_nb[0] = src0->nb[0]; - for (int k = 1;k < 3; k++) { - select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1]; + for (int k = 1; k < 3; k++) { + select_export_nb[k] = select_export_nb[k - 1] * select_export_ne[k - 1]; } - aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3); + aclTensor * select_export = + ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + select_export_ne, select_export_nb, 3); GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export); - int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]}; - size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]}; - aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3); + int64_t select_transpose_ne[] = { select_export_ne[1], select_export_ne[0], select_export_ne[2] }; + size_t select_transpose_nb[] = { select_export_nb[1], select_export_nb[0], select_export_nb[2] }; + aclTensor * select_export_transpose = + ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + select_transpose_ne, select_transpose_nb, 3); - int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]}; - size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]}; - aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]); + int64_t active_tensor_ne[] = { src1->ne[0], 1, src1->ne[1] }; + size_t active_tensor_nb[] = { src1->nb[0], src1->nb[1], src1->nb[1] }; + aclTensor * active_tensor = + ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]); - int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]}; - size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]}; - aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne,dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]); + int64_t dst_ne[] = { dst->ne[0], 1, dst->ne[1] }; + size_t dst_nb[] = { dst->nb[0], dst->nb[1], dst->nb[1] }; + aclTensor * acl_dst = ggml_cann_create_tensor(dst, dst_ne, dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]); GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2); - ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose); + ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, + select_export_transpose); } } @@ -2997,7 +2878,7 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* * @note This function assumes quantized data types and is designed for * MoE architectures with potential sparse expert routing. */ -static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // TODO: Use aclnnGroupedMatMul //dst [M, K, N, 1] ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] @@ -3007,24 +2888,23 @@ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tens GGML_TENSOR_BINARY_OP_LOCALS // copy index from npu to cpu - int64_t n_as = ne02; // A - int64_t n_ids = ids->ne[0]; // K + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K std::vector ids_host(ggml_nbytes(ids)); - ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), - ACL_MEMCPY_DEVICE_TO_HOST); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), ACL_MEMCPY_DEVICE_TO_HOST); ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); char * src0_original = (char *) src0->data; char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; + char * dst_original = (char *) dst->data; ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + ggml_tensor dst_row = *dst; const enum ggml_type type = dst->src[0]->type; - float weight_elem_size; + float weight_elem_size; if (type == GGML_TYPE_Q4_0) { weight_elem_size = float(sizeof(uint8_t)) / 2; } else if (type == GGML_TYPE_Q8_0) { @@ -3034,18 +2914,18 @@ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tens } // src0_row [D, M, 1, 1] weight without permute - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[0] = weight_elem_size; - src0_row.nb[1] = weight_elem_size * ne00; - src0_row.nb[2] = weight_elem_size * ne00; - src0_row.nb[3] = weight_elem_size * ne00; + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[0] = weight_elem_size; + src0_row.nb[1] = weight_elem_size * ne00; + src0_row.nb[2] = weight_elem_size * ne00; + src0_row.nb[3] = weight_elem_size * ne00; size_t weight_stride = ne00 * ne01 * weight_elem_size; - size_t weight_size = weight_stride * ne02 * ne03; + size_t weight_size = weight_stride * ne02 * ne03; // scale [D, M, 1, 1] -> scale && permute size_t scale_elem_size = sizeof(uint16_t); - size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; // src1_row [D, 1, 1, 1] -> input src1_row.ne[1] = 1; @@ -3063,11 +2943,11 @@ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tens //create weight for one row ggml_cann_pool_alloc weight_allocator(ctx.pool()); - void* weight_buffer = weight_allocator.alloc(nb02); + void * weight_buffer = weight_allocator.alloc(nb02); for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { for (int64_t id = 0; id < n_ids; id++) { // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]); GGML_ASSERT(i02 >= 0 && i02 < n_as); // If B = 1 (broadcast), always use 0; otherwise, use id. @@ -3077,21 +2957,19 @@ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tens int64_t i1 = id; int64_t i2 = i12; - void* src0_tmp_ptr = src0_original + i02*weight_stride; - void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride; - void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; - void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + void * src0_tmp_ptr = src0_original + i02 * weight_stride; + void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride; + void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12; + void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2; // mem cpy - ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE); - void* scale_buffer = (char*)weight_buffer + weight_stride; - ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE); + ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride, ACL_MEMCPY_DEVICE_TO_DEVICE); + void * scale_buffer = (char *) weight_buffer + weight_stride; + ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride, ACL_MEMCPY_DEVICE_TO_DEVICE); - src0_row.data = weight_buffer; - src1_row.data = src1_tmp_ptr; - dst_row.data = dst_tmp_ptr; + src0_row.data = weight_buffer; + src1_row.data = src1_tmp_ptr; + dst_row.data = dst_tmp_ptr; dst_row.src[0] = &src0_row; dst_row.src[1] = &src1_row; @@ -3101,7 +2979,7 @@ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tens return; } -void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) { const enum ggml_type type = dst->src[0]->type; switch (type) { case GGML_TYPE_F32: @@ -3118,12 +2996,11 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } } -void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ - - ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont) - ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) - ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) - ggml_tensor* src3 = dst->src[3]; // mask, fp16 +void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor * src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor * src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor * src3 = dst->src[3]; // mask, fp16 // B, N, S, D (uncont) -> B, S, N, D (cont) int64_t src0_bsnd_ne[GGML_MAX_DIMS]; @@ -3139,107 +3016,96 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ size_t src2_bsnd_nb[GGML_MAX_DIMS]; memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t)); - auto transpose12 = [](int64_t* ne, size_t* nb) { + auto transpose12 = [](int64_t * ne, size_t * nb) { int64_t ne_tmp = ne[1]; size_t nb_tmp = nb[1]; - ne[1] = ne[2]; - nb[1] = nb[2]; - ne[2] = ne_tmp; - nb[2] = nb_tmp; + ne[1] = ne[2]; + nb[1] = nb[2]; + ne[2] = ne_tmp; + nb[2] = nb_tmp; }; transpose12(src0_bsnd_ne, src0_bsnd_nb); transpose12(src1_bsnd_ne, src1_bsnd_nb); transpose12(src2_bsnd_ne, src2_bsnd_nb); - float maxBias = 0.0f; - float scaleValue = 1.0f; + float maxBias = 0.0f; + float scaleValue = 1.0f; float logitSoftcap = 0.0f; - memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float)); - memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); - memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); + memcpy(&scaleValue, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&maxBias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logitSoftcap, (float *) dst->op_params + 2, sizeof(float)); - if(logitSoftcap == 0.0f){ + if (logitSoftcap == 0.0f) { size_t faElemSize = sizeof(uint16_t); - auto faDataType = ACL_FLOAT16; //ACL_BF16; + auto faDataType = ACL_FLOAT16; //ACL_BF16; - aclTensor* acl_src0_f16_tensor = nullptr; - aclTensor* acl_src1_f16_tensor = nullptr; - aclTensor* acl_src2_f16_tensor = nullptr; + aclTensor * acl_src0_f16_tensor = nullptr; + aclTensor * acl_src1_f16_tensor = nullptr; + aclTensor * acl_src2_f16_tensor = nullptr; // Step 1: cast the src0 (Query) to fp16 if needed ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); - void* src0_f16_buffer = nullptr; + void * src0_f16_buffer = nullptr; - if(ggml_cann_type_mapping(src0->type) != faDataType){ - aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, - src0_bsnd_nb, GGML_MAX_DIMS); - src0_f16_buffer = src0_f16_allocator.alloc( - ggml_nelements(src0) * faElemSize); + if (ggml_cann_type_mapping(src0->type) != faDataType) { + aclTensor * acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, src0_bsnd_nb, GGML_MAX_DIMS); + src0_f16_buffer = src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize); - int64_t* src0_f16_ne = src0_bsnd_ne; - size_t src0_f16_nb[GGML_MAX_DIMS]; + int64_t * src0_f16_ne = src0_bsnd_ne; + size_t src0_f16_nb[GGML_MAX_DIMS]; src0_f16_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ + for (int i = 1; i < GGML_MAX_DIMS; ++i) { src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; } - acl_src0_f16_tensor = ggml_cann_create_tensor( - src0_f16_buffer, faDataType, faElemSize, - src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS - ); + acl_src0_f16_tensor = ggml_cann_create_tensor(src0_f16_buffer, faDataType, faElemSize, src0_f16_ne, + src0_f16_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); ggml_cann_release_resources(ctx, acl_src0_f32_tensor); - }else{ - acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, - src0_bsnd_nb, GGML_MAX_DIMS); + } else { + acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, src0_bsnd_nb, GGML_MAX_DIMS); } // Step 2: create the acl tensors for src1 (Key), src2 (Value), // and the direct output from FusedInferAttention - acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, - src1_bsnd_nb, GGML_MAX_DIMS); - acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, - src2_bsnd_nb, GGML_MAX_DIMS); + acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp - aclTensor* bcast_pse_tensor = nullptr; + aclTensor * bcast_pse_tensor = nullptr; ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); - if(src3 != nullptr){ + if (src3 != nullptr) { // Construct the truncated pse tensor (common for prefill/decode) int64_t trunc_pse_ne[GGML_MAX_DIMS] = { - src3->ne[0], // D - src0->ne[1], // S (number of Q tokens) - src3->ne[2], // mask N - src3->ne[3] // B + src3->ne[0], // D + src0->ne[1], // S (number of Q tokens) + src3->ne[2], // mask N + src3->ne[3] // B }; - size_t* trunc_pse_nb = src3->nb; + size_t * trunc_pse_nb = src3->nb; - aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( - src3->data, ACL_FLOAT16, sizeof(uint16_t), - trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS - ); + aclTensor * acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); int64_t bcast_pse_ne[GGML_MAX_DIMS]; - size_t bcast_pse_nb[GGML_MAX_DIMS]; - bcast_pse_ne[0] = src3->ne[0]; // D - bcast_pse_ne[1] = src0->ne[1]; // S - bcast_pse_ne[2] = src0->ne[2]; // N (num_heads) - bcast_pse_ne[3] = src3->ne[3]; // B + size_t bcast_pse_nb[GGML_MAX_DIMS]; + bcast_pse_ne[0] = src3->ne[0]; // D + bcast_pse_ne[1] = src0->ne[1]; // S + bcast_pse_ne[2] = src0->ne[2]; // N (num_heads) + bcast_pse_ne[3] = src3->ne[3]; // B if (maxBias == 0.0f) { // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2) // Construct the bcast tensor (simulate repeat on the head dimension using stride=0) bcast_pse_nb[0] = sizeof(uint16_t); bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0]; - bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data + bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data bcast_pse_nb[3] = src3->nb[3]; - bcast_pse_tensor = ggml_cann_create_tensor( - src3->data, ACL_FLOAT16, sizeof(uint16_t), - bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS - ); + bcast_pse_tensor = ggml_cann_create_tensor(src3->data, ACL_FLOAT16, sizeof(uint16_t), bcast_pse_ne, + bcast_pse_nb, GGML_MAX_DIMS); ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); } else { @@ -3248,35 +3114,31 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; } - void* bcast_pse_buffer = bcast_pse_allocator.alloc( - ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t) - ); + void * bcast_pse_buffer = + bcast_pse_allocator.alloc(ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); - bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), - bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS - ); + bcast_pse_tensor = ggml_cann_create_tensor(bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); - int64_t repeats[] = {1, src0->ne[2], 1, 1}; + int64_t repeats[] = { 1, src0->ne[2], 1, 1 }; aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); // alibi // Compute the slope if needed. Derived from ggml_cann_softmax(). - const int64_t n_heads = src0->ne[2]; + const int64_t n_heads = src0->ne[2]; ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t)); - void* slope_buffer = slope_allocator.get(); + void * slope_buffer = slope_allocator.get(); aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16); - int64_t slope_ne[] = {1, 1, n_heads, 1}; - size_t slope_nb[GGML_MAX_DIMS]; + int64_t slope_ne[] = { 1, 1, n_heads, 1 }; + size_t slope_nb[GGML_MAX_DIMS]; slope_nb[0] = sizeof(uint16_t); - for(int i = 1;ine[2]; // N - int64_t numKeyValueHeads = src1->ne[2]; + int64_t numHeads = src0->ne[2]; // N + int64_t numKeyValueHeads = src1->ne[2]; // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) - int64_t preTokens = 65535; - int64_t nextTokens = 65535; - char layout[5] = {'B', 'S', 'N', 'D', 0}; - int64_t sparseMode = 0; - int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; - int64_t blockSize = 0; - int64_t antiquantMode = 0; - bool softmaxLseFlag = false; - int64_t keyAntiquantMode = 0; + int64_t preTokens = 65535; + int64_t nextTokens = 65535; + char layout[5] = { 'B', 'S', 'N', 'D', 0 }; + int64_t sparseMode = 0; + int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; + int64_t blockSize = 0; + int64_t antiquantMode = 0; + bool softmaxLseFlag = false; + int64_t keyAntiquantMode = 0; int64_t valueAntiquantMode = 0; GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - aclTensor * fa_dst_tensor = nullptr; - aclTensor * acl_dst_tensor = nullptr; + aclTensor * fa_dst_tensor = nullptr; + aclTensor * acl_dst_tensor = nullptr; ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); if (dst->type == GGML_TYPE_F32) { - void* out_f16_buffer = out_f16_allocator.alloc( - ggml_nelements(dst) * faElemSize); + void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - int64_t* out_f16_ne = src0_bsnd_ne; - size_t out_f16_nb[GGML_MAX_DIMS]; + int64_t * out_f16_ne = src0_bsnd_ne; + size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ + for (int i = 1; i < GGML_MAX_DIMS; ++i) { out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } - fa_dst_tensor = ggml_cann_create_tensor( - out_f16_buffer, faDataType, faElemSize, - out_f16_ne, out_f16_nb, GGML_MAX_DIMS - ); - } - else { + fa_dst_tensor = + ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); + } else { fa_dst_tensor = ggml_cann_create_tensor(dst); } - GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, - acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v - bcast_pse_tensor, nullptr, // pse, mask - nullptr, nullptr, // actSeqLen, actSeqLenkv - nullptr, nullptr, // deqScale1, quantScale1 - nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 - nullptr, nullptr, // antiquantScale, antiquantOffset - nullptr, // blockTable - nullptr, nullptr, // qPadSize, kvPadSize - nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset - nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset - nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen - numHeads, scaleValue, // heads, scaleValue - preTokens, nextTokens, // preTokens, nextTokens - layout, // inputLayout - numKeyValueHeads, // numKVHeads - sparseMode, innerPrecise, // sparseMode, innerPrecise - blockSize, antiquantMode, // blockSize, antiquantMode - softmaxLseFlag, // softmaxLseFlag - keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode - fa_dst_tensor, // attentionOut - nullptr // softmaxLse + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, acl_q_tensor, acl_k_tensor_list, + acl_v_tensor_list, // q, k, v + bcast_pse_tensor, nullptr, // pse, mask + nullptr, nullptr, // actSeqLen, actSeqLenkv + nullptr, nullptr, // deqScale1, quantScale1 + nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // qPadSize, kvPadSize + nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset + nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset + nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen + numHeads, scaleValue, // heads, scaleValue + preTokens, nextTokens, // preTokens, nextTokens + layout, // inputLayout + numKeyValueHeads, // numKVHeads + sparseMode, innerPrecise, // sparseMode, innerPrecise + blockSize, antiquantMode, // blockSize, antiquantMode + softmaxLseFlag, // softmaxLseFlag + keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode + fa_dst_tensor, // attentionOut + nullptr // softmaxLse ); if (dst->type == GGML_TYPE_F32) { // Step 6: post-processing, permute and cast to f32 - aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + aclTensor * acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); } - ggml_cann_release_resources(ctx, acl_src0_f16_tensor, - acl_k_tensor_list, - acl_v_tensor_list, - fa_dst_tensor, - acl_dst_tensor, - bcast_pse_tensor); + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_k_tensor_list, acl_v_tensor_list, fa_dst_tensor, + acl_dst_tensor, bcast_pse_tensor); } else { GGML_ABORT("Function is not implemented."); diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h old mode 100755 new mode 100644 index 5c510cc9..ec7455af --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -62,7 +62,7 @@ * @param dst The ggml tensor representing the destination, which op is * GGML_OP_REPEAT and specifies the desired dimensions. */ -void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN @@ -82,7 +82,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the result of the Leaky ReLU * activation is stored, which op is `GGML_OP_LEAKY_RELU` */ -void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Concatenates multiple tensors along a specified dimension using the @@ -97,7 +97,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @attention tensorList length should be 2 and the dimension using for concat * default to 1. */ -void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Generates a sequence of evenly spaced values within a specified @@ -113,7 +113,7 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst); * `start`, 'stop' and 'step' are in dst->op_params and dst->op is * `GGML_OP_ARANGE`. */ -void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies a clamp operation to the elements of a ggml tensor using the @@ -131,7 +131,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the clamped values will be stored. * dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params. */ -void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Scales the elements of a ggml tensor by a constant factor using the @@ -148,7 +148,7 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the scaled values will be stored. * dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params. */ -void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Sorts the elements of a ggml tensor and returns the indices that @@ -163,7 +163,7 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the sorted indices will be stored. * dst->op is `GGML_OP_ARGSORT`. */ -void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the Layer Normalization for a ggml tensor using the CANN @@ -185,7 +185,7 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the normalized values will be stored. * @attention `Var` defaults to dst->ne[0]. */ -void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the Group Normalization for a ggml tensor using the CANN @@ -209,7 +209,7 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); * * @attention eps defaults to 1e-6f. */ -void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the accumulation of tensors using the CANN backend. @@ -228,7 +228,7 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the accumulated values will be stored. * `inplace` is in dst->params, and dst->op is `GGML_OP_ACC`. */ -void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the sum of elements along the last dimension of a ggml tensor @@ -244,7 +244,7 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst); * * @attention `reduce_dims` defaults to 3, which means the last dimension. */ -void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the sum of elements in a ggml tensor. @@ -258,7 +258,7 @@ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); * */ -void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using @@ -274,8 +274,7 @@ void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the upsampled values will be stored. * dst->op is `GGML_OP_UPSCALE`. */ -void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, - ggml_tensor* dst); +void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Pads a ggml tensor to match the dimensions of the destination tensor @@ -290,7 +289,7 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, * @param dst The destination tensor, which specifies the target dimensions for * padding. dst->op is `GGML_OP_PAD`. */ -void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Executes a 2D pooling operation on a ggml tensor using the CANN @@ -307,7 +306,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor on which the pooling operation is to be * performed. dst->op is `GGML_OP_POOL_2D`. */ -void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Duplicates a ggml tensor using the CANN backend. @@ -326,7 +325,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst); * different shape and dst is no-contiguous. * @note: This func need to simplify. */ -void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the Root Mean Square (RMS) normalization of a ggml tensor @@ -348,7 +347,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the normalized values will be stored. * dst->op is `GGML_OP_RMS_NORM`. */ -void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies a diagonal mask to the tensor with a specified value. @@ -363,7 +362,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); * `GGML_OP_DIAG_MASK` * @param value The value to use for masking. */ -void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value); +void ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value); /** * @brief Performs an image-to-column transformation on the input tensor. @@ -378,7 +377,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float * @param dst The destination tensor that stores the result of the operation. * dst->op is `GGML_OP_IM2COL`. */ -void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes time step embeddings using sine and cosine functions. @@ -392,10 +391,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the result of the embedding operation * will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`. */ -void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst); // @see ggml_cann_dup. -void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the softmax activation with optional masking. @@ -417,7 +416,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the result will be stored. dst->op is * `GGML_OP_SOFTMAX`. */ -void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Extracts specific rows from a tensor based on indices. @@ -429,7 +428,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param ctx The backend CANN context for executing operations. * @param dst The destination tensor where the extracted rows will be stored. */ -void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Writes specific rows into a tensor at positions specified by indices. @@ -441,7 +440,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param ctx The backend CANN context for executing operations. * @param dst The destination tensor where the specified rows will be updated. */ -void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Executes matrix multiplication for the given tensor. @@ -454,7 +453,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor for storing the result of the matrix * multiplication. dst->op is `GGML_OP_MUL_MAT`. */ -void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies Rotary Positional Embedding (RoPE) to the input tensor. @@ -477,7 +476,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @note The function currently does not support cases where the freq_scale is * not equal 1. */ -void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the index of the maximum value along the specified dimension @@ -492,7 +491,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the indices of the maximum values will * be stored. dst->op is `GGML_OP_ARGMAX`. */ -void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Adds two tensors element-wise and stores the result in a destination @@ -509,8 +508,10 @@ void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param acl_src1 The second source tensor. * @param acl_dst The destination tensor where the result will be stored. */ -void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, - aclTensor* acl_src1, aclTensor* acl_dst = nullptr); +void aclnn_add(ggml_backend_cann_context & ctx, + aclTensor * acl_src0, + aclTensor * acl_src1, + aclTensor * acl_dst = nullptr); /** * @brief Sub two tensors element-wise and stores the result in a destination @@ -527,8 +528,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, * @param acl_src1 The second source tensor. * @param acl_dst The destination tensor where the result will be stored. */ -void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, - aclTensor* acl_src1, aclTensor* acl_dst = nullptr); +void aclnn_sub(ggml_backend_cann_context & ctx, + aclTensor * acl_src0, + aclTensor * acl_src1, + aclTensor * acl_dst = nullptr); /** * @brief Performs element-wise multiplication of two tensors and stores the @@ -546,8 +549,10 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, * @param acl_other The second tensor for element-wise multiplication. * @param acl_dst The destination tensor where the result will be stored. */ -void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst = nullptr); +void aclnn_mul(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_other, + aclTensor * acl_dst = nullptr); /** * @brief Matrix division, optionally in-place. @@ -567,8 +572,10 @@ void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param inplace Flag indicating whether to perform the operation in-place on * `acl_src`. */ -void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst = nullptr); +void aclnn_div(ggml_backend_cann_context & ctx, + aclTensor * acl_src, + aclTensor * acl_other, + aclTensor * acl_dst = nullptr); /** * @brief Applies element-wise cosine function to the elements of a tensor. @@ -584,8 +591,7 @@ void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param acl_dst The destination tensor where the cosine results will be * stored. */ -void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst); +void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst); /** * @brief Applies element-wise sine function to the elements of a tensor. @@ -602,8 +608,7 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param acl_src The source tensor on which the sine function will be applied. * @param acl_dst The destination tensor where the sine results will be stored. */ -void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst); +void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst); /** * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one @@ -621,8 +626,12 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, * @param acl_src1 Output pointer to the created ACL tensor corresponding to src1. * @param acl_dst Output pointer to the created ACL tensor corresponding to dst. */ -void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, - aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst); +void bcast_shape(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + aclTensor ** acl_src0, + aclTensor ** acl_src1, + aclTensor ** acl_dst); /** * @brief Computes the 1D transposed convolution (deconvolution) of a ggml @@ -637,7 +646,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, * @param dst The destination tensor where the transposed convolution result * will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`. */ -void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor @@ -662,7 +671,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds * @param dst The destination tensor where the ELU-activated result will be stored. * dst->op is expected to be `GGML_OP_ELU`. */ -void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Computes the mean of a ggml tensor element-wise using the CANN backend. @@ -677,7 +686,7 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the mean result will be stored. * dst->op is expected to be `GGML_OP_MEAN`. */ -void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies 1D reflect padding to a ggml tensor using the CANN backend. @@ -692,7 +701,7 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the padded result will be stored. * dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`. */ -void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Counts the number of equal elements in two ggml tensors using the CANN backend. @@ -708,7 +717,7 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the result will be stored. * dst->op is expected to be `GGML_OP_COUNT_EQUAL`. */ -void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Applies the Step activation function to a ggml tensor using the CANN backend. @@ -723,7 +732,7 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the result will be stored. * dst->op is expected to be `GGML_OP_STEP`. */ -void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Performs the Flash Attention extended operator using the CANN backend. @@ -738,59 +747,46 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); * @param dst The destination tensor where the result will be stored. * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`. */ -void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst); /* * @brief A generic wrapper for ACL resources with custom deleter support. */ -using any_acl_resource = std::unique_ptr>; +using any_acl_resource = std::unique_ptr>; /** * @brief Trait structure used to define how to destroy a given ACL resource type. * * @tparam T ACL resource type. */ -template -struct acl_resource_traits; +template struct acl_resource_traits; /** * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource. */ -template<> -struct acl_resource_traits { - static void destroy(void* p) { - ACL_CHECK(aclDestroyTensor(static_cast(p))); - } +template <> struct acl_resource_traits { + static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast(p))); } }; /** * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource. */ -template<> -struct acl_resource_traits { - static void destroy(void* p) { - ACL_CHECK(aclDestroyIntArray(static_cast(p))); - } +template <> struct acl_resource_traits { + static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast(p))); } }; /** * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource. */ -template<> -struct acl_resource_traits { - static void destroy(void* p) { - ACL_CHECK(aclDestroyScalar(static_cast(p))); - } +template <> struct acl_resource_traits { + static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast(p))); } }; /** * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource. */ -template<> -struct acl_resource_traits { - static void destroy(void* p) { - ACL_CHECK(aclDestroyTensorList(static_cast(p))); - } +template <> struct acl_resource_traits { + static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast(p))); } }; /** @@ -800,14 +796,8 @@ struct acl_resource_traits { * @param ptr Raw pointer to ACL resource. * @return any_acl_resource Smart pointer that handles destruction. */ -template -any_acl_resource make_acl_resource(T* ptr) { - return any_acl_resource( - static_cast(ptr), - [](void* p) { - acl_resource_traits::destroy(p); - } - ); +template any_acl_resource make_acl_resource(T * ptr) { + return any_acl_resource(static_cast(ptr), [](void * p) { acl_resource_traits::destroy(p); }); } /** @@ -817,8 +807,7 @@ any_acl_resource make_acl_resource(T* ptr) { * @param vec Target vector to hold ACL resources. * @param args Raw pointers to ACL resources. */ -template -void register_acl_resources(std::vector& vec, Args*... args) { +template void register_acl_resources(std::vector & vec, Args *... args) { (vec.emplace_back(make_acl_resource(args)), ...); } @@ -826,39 +815,36 @@ void register_acl_resources(std::vector& vec, Args*... args) { * @brief Task class that wraps the execution of an aclnn function call. */ class aclnn_task : public cann_task { - public: - aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr, - uint64_t workspace_size, aclOpExecutor * executor, - aclrtStream stream) : - aclnn_func_(aclnn_func), - workspace_addr_(workspace_addr), - workspace_size_(workspace_size), - executor_(executor), - stream_(stream) {} - virtual void run_task() override { - ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); - } - private: - aclnn_func_t aclnn_func_; - void * workspace_addr_; - uint64_t workspace_size_; - aclOpExecutor * executor_; - aclrtStream stream_; + public: + aclnn_task(aclnn_func_t aclnn_func, + void * workspace_addr, + uint64_t workspace_size, + aclOpExecutor * executor, + aclrtStream stream) : + aclnn_func_(aclnn_func), + workspace_addr_(workspace_addr), + workspace_size_(workspace_size), + executor_(executor), + stream_(stream) {} + + virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); } + private: + aclnn_func_t aclnn_func_; + void * workspace_addr_; + uint64_t workspace_size_; + aclOpExecutor * executor_; + aclrtStream stream_; }; /** * @brief Task class that releases ACL resources after usage. */ class release_resource_task : public cann_task { -public: - release_resource_task(std::vector&& resources){ - resource_ = std::move(resources); - } + public: + release_resource_task(std::vector && resources) { resource_ = std::move(resources); } - virtual void run_task() override { - resource_.clear(); - } -private: + virtual void run_task() override { resource_.clear(); } + private: std::vector resource_; }; @@ -866,38 +852,40 @@ private: * @brief Task class for performing asynchronous memory copy operations. */ class async_memcpy_task : public cann_task { -public: - async_memcpy_task(void* dst, const void* src, size_t size, - aclrtMemcpyKind kind, aclrtStream stream) - : dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {} + public: + async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) : + dst_(dst), + src_(src), + size_(size), + kind_(kind), + stream_(stream) {} - virtual void run_task() override { - ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); - } -private: - void* dst_; - const void* src_; - size_t size_; + virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); } + private: + void * dst_; + const void * src_; + size_t size_; aclrtMemcpyKind kind_; - aclrtStream stream_; + aclrtStream stream_; }; /** * @brief Task class for performing asynchronous memory set operations. */ class async_memset_task : public cann_task { - public: - async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream) - : buffer_(buffer), size_(size), value_(value), stream_(stream) {} + public: + async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) : + buffer_(buffer), + size_(size), + value_(value), + stream_(stream) {} - virtual void run_task() override { - ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); - } - private: - void* buffer_; - size_t size_; - int32_t value_; - aclrtStream stream_; + virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); } + private: + void * buffer_; + size_t size_; + int32_t value_; + aclrtStream stream_; }; /** @@ -918,25 +906,24 @@ class async_memset_task : public cann_task { * same stream are executed in queue order. */ -#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ - do { \ - uint64_t workspaceSize = 0; \ - aclOpExecutor * executor; \ - void * workspaceAddr = nullptr; \ - ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\ - /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ - if (workspaceSize > 0) { \ - ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ - workspaceAddr = workspace_allocator.get(); \ - } \ - if (CTX.async_mode) { \ - auto task = \ - std::make_unique(aclnn##OP_NAME, workspaceAddr, workspaceSize, \ - executor, CTX.stream()); \ - CTX.task_queue.submit_task(std::move(task)); \ - } else { \ - ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\ - } \ +#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ + do { \ + uint64_t workspaceSize = 0; \ + aclOpExecutor * executor; \ + void * workspaceAddr = nullptr; \ + ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \ + /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ + if (workspaceSize > 0) { \ + ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ + workspaceAddr = workspace_allocator.get(); \ + } \ + if (CTX.async_mode) { \ + auto task = \ + std::make_unique(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \ + CTX.task_queue.submit_task(std::move(task)); \ + } else { \ + ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \ + } \ } while (0) /** @@ -947,11 +934,10 @@ class async_memset_task : public cann_task { * @param ctx Backend context which manages task submission and async mode. * @param args Pointers to ACL resources to be released. */ -template -void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) { +template void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) { std::vector resources; register_acl_resources(resources, std::forward(args)...); - if(ctx.async_mode) { + if (ctx.async_mode) { auto task = std::make_unique(std::move(resources)); ctx.task_queue.submit_task(std::move(task)); } @@ -966,8 +952,11 @@ void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... arg * @param len Size of memory to copy (in bytes). * @param kind Type of memory copy (host-to-device, device-to-host, etc). */ -inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst, - const void * src, size_t len, aclrtMemcpyKind kind) { +inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, + void * dst, + const void * src, + size_t len, + aclrtMemcpyKind kind) { if (ctx.async_mode) { auto task = std::make_unique(dst, const_cast(src), len, kind, ctx.stream()); ctx.task_queue.submit_task(std::move(task)); @@ -976,8 +965,11 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst, } } -inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst, - const void * src, size_t len, aclrtMemcpyKind kind) { +inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, + void * dst, + const void * src, + size_t len, + aclrtMemcpyKind kind) { if (ctx->async_mode) { auto task = std::make_unique(dst, const_cast(src), len, kind, ctx->stream()); ctx->task_queue.submit_task(std::move(task)); @@ -994,8 +986,7 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst, * @param size Size of the memory buffer (in bytes). * @param value Value to set in the buffer. */ -inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, - size_t size, int value) { +inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) { if (ctx.async_mode) { auto task = std::make_unique(buffer, size, value, ctx.stream()); ctx.task_queue.submit_task(std::move(task)); @@ -1029,7 +1020,7 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe * @param dst The destination tensor where the expert-weighted token outputs are stored. * Expected to be of shape [M, K, N, 1]. */ -void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Check whether a tensor is a weight tensor for matrix multiplication. @@ -1041,20 +1032,14 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst); * * @param tensor Pointer to the target ggml_tensor object (const-qualified). */ -static bool is_matmul_weight(const ggml_tensor* tensor) { - std::string name = ggml_get_name(tensor); - static const std::unordered_set weight_suffixes{ - "output.weight", - "attn_q.weight", - "attn_k.weight", - "attn_v.weight", - "attn_output.weight", - "ffn_gate.weight", - "ffn_up.weight", - "ffn_down.weight" - }; +static bool is_matmul_weight(const ggml_tensor * tensor) { + std::string name = ggml_get_name(tensor); + static const std::unordered_set weight_suffixes{ "output.weight", "attn_q.weight", + "attn_k.weight", "attn_v.weight", + "attn_output.weight", "ffn_gate.weight", + "ffn_up.weight", "ffn_down.weight" }; - for (const auto& suffix : weight_suffixes) { + for (const auto & suffix : weight_suffixes) { if (name.find(suffix) != std::string::npos) { return true; } @@ -1078,14 +1063,13 @@ static bool is_matmul_weight(const ggml_tensor* tensor) { * @param ctx The CANN backend context used to manage execution and resources. * @param dst The destination tensor. */ -template -void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; +template void ggml_cann_binary_op(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; - aclTensor* acl_src0; - aclTensor* acl_src1; - aclTensor* acl_dst; + aclTensor * acl_src0; + aclTensor * acl_src1; + aclTensor * acl_dst; // Need bcast bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst); @@ -1094,7 +1078,6 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst); } - /** * @brief Applies a unary operation to an input tensor using the CANN backend. * @@ -1107,12 +1090,12 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param ctx The CANN backend context for managing resources and execution. * @param dst The destination tensor. Its src[0] is treated as the input tensor. */ -template - void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +template +void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor * acl_src = ggml_cann_create_tensor(src); + aclTensor * acl_dst = ggml_cann_create_tensor(dst); unary_op(ctx, acl_src, acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst); @@ -1138,9 +1121,9 @@ template * * @see GGML_CANN_CALL_OP_UNARY */ -void ggml_cann_op_unary( - std::function unary_op, - ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_op_unary(std::function unary_op, + ggml_backend_cann_context & ctx, + ggml_tensor * dst); /** * @brief Applies a gated (GLU-style) unary operation using the CANN backend. @@ -1172,9 +1155,9 @@ void ggml_cann_op_unary( * * @see GGML_CANN_CALL_OP_UNARY_GATED */ -void ggml_cann_op_unary_gated( - std::function unary_op, - ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_op_unary_gated(std::function unary_op, + ggml_backend_cann_context & ctx, + ggml_tensor * dst); /** * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary. @@ -1197,16 +1180,13 @@ void ggml_cann_op_unary_gated( * @see ggml_cann_op_unary * @see GGML_CANN_CALL_ACLNN_OP */ -#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \ - do { \ - auto lambda = [](ggml_backend_cann_context& ctx, \ - aclTensor* acl_src, \ - aclTensor* acl_dst) { \ - GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ - }; \ - ggml_cann_op_unary(lambda, ctx, dst); \ - } \ - while (0) +#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_op_unary(lambda, ctx, dst); \ + } while (0) /** * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated. @@ -1229,15 +1209,12 @@ void ggml_cann_op_unary_gated( * @see ggml_cann_op_unary_gated * @see GGML_CANN_CALL_ACLNN_OP */ -#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \ - do { \ - auto lambda = [](ggml_backend_cann_context& ctx, \ - aclTensor* acl_src, \ - aclTensor* acl_dst) { \ - GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ - }; \ - ggml_cann_op_unary_gated(lambda, ctx, dst); \ - } \ - while (0) +#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_op_unary_gated(lambda, ctx, dst); \ + } while (0) #endif // CANN_ACLNN_OPS diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h old mode 100755 new mode 100644 index debbcadc..e87dbcf3 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -44,7 +44,7 @@ #include "../include/ggml.h" #include "../ggml-impl.h" -#define MATRIX_ROW_PADDING 512 +#define MATRIX_ROW_PADDING 512 #define GGML_CANN_MAX_STREAMS 8 /** @@ -56,8 +56,7 @@ * @param line The line number at which the error occurred. * @param msg The error message. */ -[[noreturn]] void ggml_cann_error(const char* stmt, const char* func, - const char* file, int line, const char* msg); +[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg); /** * @brief Checks the result of a CANN function call and invokes the error @@ -89,25 +88,24 @@ struct ggml_cann_device_info { * @brief Information about a single CANN device. */ struct cann_device_info { - int cc; /**< Compute capability. */ + int cc; /**< Compute capability. */ size_t smpb; /**< Maximum shared memory per block. */ - bool vmm; /**< Virtual memory support. */ + bool vmm; /**< Virtual memory support. */ size_t vmm_granularity; /**< Granularity of virtual memory. */ size_t total_vram; /**< Total video RAM available on the device. */ }; - cann_device_info devices[GGML_CANN_MAX_DEVICES] = - {}; /**< Array of CANN device information. */ + cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. */ }; -const ggml_cann_device_info& ggml_cann_info(); +const ggml_cann_device_info & ggml_cann_info(); -void ggml_cann_set_device(int32_t device); +void ggml_cann_set_device(int32_t device); int32_t ggml_cann_get_device(); -std::optional get_env(const std::string& name); -bool parse_bool(const std::string& value); -int parse_integer(const std::string& value); +std::optional get_env(const std::string & name); +bool parse_bool(const std::string & value); +int parse_integer(const std::string & value); /** * @brief Abstract base class for memory pools used by CANN. @@ -126,7 +124,7 @@ struct ggml_cann_pool { * will be stored. * @return Pointer to the allocated memory block. */ - virtual void* alloc(size_t size, size_t* actual_size) = 0; + virtual void * alloc(size_t size, size_t * actual_size) = 0; /** * @brief Frees a previously allocated memory block. @@ -136,16 +134,16 @@ struct ggml_cann_pool { * @note Note that all CANN opertors are running async. Make sure memory is * still avaiable before this operator finished. */ - virtual void free(void* ptr, size_t size) = 0; + virtual void free(void * ptr, size_t size) = 0; }; /** * @brief RAII wrapper for managing memory allocations from a CANN memory pool. */ struct ggml_cann_pool_alloc { - ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */ - void* ptr = nullptr; /**< Pointer to the allocated memory block. */ - size_t actual_size = 0; /**< Actual size of the allocated memory block. */ + ggml_cann_pool * pool = nullptr; /**< Pointer to the memory pool. */ + void * ptr = nullptr; /**< Pointer to the allocated memory block. */ + size_t actual_size = 0; /**< Actual size of the allocated memory block. */ /** * @brief Default constructor. @@ -156,16 +154,14 @@ struct ggml_cann_pool_alloc { * @brief Constructor that initializes the memory pool. * @param pool Reference to the memory pool. */ - explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {} + explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {} /** * @brief Constructor that initializes the memory pool and allocates memory. * @param pool Reference to the memory pool. * @param size Size of the memory block to allocate. */ - ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) { - alloc(size); - } + ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); } /** * @brief Destructor that frees the allocated memory block. @@ -181,7 +177,7 @@ struct ggml_cann_pool_alloc { * @param size Size of the memory block to allocate. * @return Pointer to the allocated memory block. */ - void* alloc(size_t size) { + void * alloc(size_t size) { GGML_ASSERT(pool != nullptr); GGML_ASSERT(ptr == nullptr); ptr = pool->alloc(size, &this->actual_size); @@ -194,7 +190,7 @@ struct ggml_cann_pool_alloc { * @param size Size of the memory block to allocate. * @return Pointer to the allocated memory block. */ - void* alloc(ggml_cann_pool& pool, size_t size) { + void * alloc(ggml_cann_pool & pool, size_t size) { this->pool = &pool; return alloc(size); } @@ -203,25 +199,25 @@ struct ggml_cann_pool_alloc { * @brief Gets the pointer to the allocated memory block. * @return Pointer to the allocated memory block. */ - void* get() { return ptr; } + void * get() { return ptr; } // Deleted copy constructor - ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete; + ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete; // Deleted move constructor - ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete; + ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete; // Deleted copy assignment operator - ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete; + ggml_cann_pool_alloc & operator=(const ggml_cann_pool_alloc &) = delete; // Deleted move assignment operator - ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete; + ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete; }; /** * @brief Function pointer type for ACLNN operator calls. */ -using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream); +using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream); /** * @brief Base class for all CANN tasks to be submitted to the task queue. @@ -229,7 +225,7 @@ using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStrea * Users should override the run_task() method with actual task logic. */ class cann_task { -public: + public: virtual void run_task() {} }; @@ -237,16 +233,20 @@ public: * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances. */ class cann_task_queue { -public: + public: /** * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device. * * @param capacity Queue capacity. Must be a power of 2. * @param device Target device ID (used for context setting). */ - explicit cann_task_queue(size_t capacity, int32_t device) - : buffer_(capacity), capacity_(capacity), head_(0), tail_(0), - running_(false), device_(device) { + explicit cann_task_queue(size_t capacity, int32_t device) : + buffer_(capacity), + capacity_(capacity), + head_(0), + tail_(0), + running_(false), + device_(device) { GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2"); mask_ = capacity_ - 1; } @@ -257,7 +257,7 @@ public: * @param item Unique pointer to the task. * @return true if the task was successfully enqueued, false if the queue was full. */ - bool enqueue(std::unique_ptr&& item) { + bool enqueue(std::unique_ptr && item) { size_t next_tail = (tail_ + 1) & mask_; if (next_tail == head_) { @@ -276,17 +276,16 @@ public: * * @param task Task to be submitted. */ - void submit_task(std::unique_ptr&& task) { - while(!enqueue(std::move(task))) { + void submit_task(std::unique_ptr && task) { + while (!enqueue(std::move(task))) { std::this_thread::yield(); continue; } if (!running_) { running_ = true; - thread_ = std::thread(&cann_task_queue::execute, this); + thread_ = std::thread(&cann_task_queue::execute, this); } - } /** @@ -309,7 +308,7 @@ public: } } -private: + private: /** * @brief Worker thread function that continuously dequeues and executes tasks. */ @@ -317,7 +316,7 @@ private: ggml_cann_set_device(device_); while (running_) { - if(head_ == tail_) { + if (head_ == tail_) { std::this_thread::yield(); continue; } @@ -330,24 +329,24 @@ private: } std::vector> buffer_; - const size_t capacity_; - size_t mask_; - size_t head_; - size_t tail_; - bool running_; - std::thread thread_; - int32_t device_; + const size_t capacity_; + size_t mask_; + size_t head_; + size_t tail_; + bool running_; + std::thread thread_; + int32_t device_; }; #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { // dst tensor - void * node_address; + void * node_address; int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; // src tensor - void * src_address[GGML_MAX_SRC]; + void * src_address[GGML_MAX_SRC]; int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; @@ -376,13 +375,11 @@ struct ggml_cann_graph { * move existing graphs to the front (most recently used), and clear the cache. */ struct ggml_cann_graph_lru_cache { - size_t capacity; /**< Maximum number of graphs in the cache. */ + size_t capacity; /**< Maximum number of graphs in the cache. */ - std::list cache_list; /**< List storing cached graphs as raw pointers. */ + std::list cache_list; /**< List storing cached graphs as raw pointers. */ - ggml_cann_graph_lru_cache() { - capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); - } + ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); } /** * @brief Push a new graph to the front of the cache. @@ -390,11 +387,11 @@ struct ggml_cann_graph_lru_cache { * @param new_node Pointer to the new ggml_cann_graph to cache. * Ownership is transferred to the cache (cache will delete it). */ - void push(ggml_cann_graph* new_node) { + void push(ggml_cann_graph * new_node) { if (cache_list.size() >= capacity) { - ggml_cann_graph* old = cache_list.back(); + ggml_cann_graph * old = cache_list.back(); cache_list.pop_back(); - delete old; // free the old graph + delete old; // free the old graph } cache_list.push_front(new_node); } @@ -403,7 +400,7 @@ struct ggml_cann_graph_lru_cache { * @brief Move an existing graph to the front of the cache. * @param node Pointer to the ggml_cann_graph to move. */ - void move_to_front(ggml_cann_graph* node) { + void move_to_front(ggml_cann_graph * node) { cache_list.remove(node); cache_list.push_front(node); } @@ -421,92 +418,89 @@ struct ggml_cann_graph_lru_cache { /** * @brief Destructor that clears the cache and frees all cached graphs. */ - ~ggml_cann_graph_lru_cache() { - clear(); - } + ~ggml_cann_graph_lru_cache() { clear(); } }; #endif // USE_ACL_GRAPH struct ggml_cann_rope_cache { ~ggml_cann_rope_cache() { - if(theta_scale_cache != nullptr) { + if (theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(theta_scale_cache)); } - if(sin_cache != nullptr) { + if (sin_cache != nullptr) { ACL_CHECK(aclrtFree(sin_cache)); } - if(cos_cache != nullptr) { + if (cos_cache != nullptr) { ACL_CHECK(aclrtFree(cos_cache)); } } - void* theta_scale_cache = nullptr; + void * theta_scale_cache = nullptr; int64_t theta_scale_length = 0; // sin/cos cache, used only to accelerate first layer on each device - void* sin_cache = nullptr; - void* cos_cache = nullptr; - int64_t position_length = 0; + void * sin_cache = nullptr; + void * cos_cache = nullptr; + int64_t position_length = 0; // Properties to check before reusing the sincos cache - bool cached = false; - float ext_factor = 0.0f; - float theta_scale = 0.0f; - float freq_scale = 0.0f; - float attn_factor = 0.0f; - bool is_neox = false; + bool cached = false; + float ext_factor = 0.0f; + float theta_scale = 0.0f; + float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; }; struct ggml_cann_tensor_cache { ~ggml_cann_tensor_cache() { - if(cache != nullptr) { + if (cache != nullptr) { ACL_CHECK(aclrtFree(cache)); } } - void* cache = nullptr; - int64_t size = 0; + void * cache = nullptr; + int64_t size = 0; }; /** * @brief Context for managing CANN backend operations. */ struct ggml_backend_cann_context { - int32_t device; /**< Device ID. */ - std::string name; /**< Name of the device. */ - std::string description; /**< Description of the device. */ - aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ + int32_t device; /**< Device ID. */ + std::string name; /**< Name of the device. */ + std::string description; /**< Description of the device. */ + aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ #ifdef USE_ACL_GRAPH /// Cached CANN ACL graph used for executing the current ggml computation graph. ggml_cann_graph_lru_cache graph_lru_cache; - bool acl_graph_mode = true; + bool acl_graph_mode = true; #endif - cann_task_queue task_queue; - bool async_mode; + cann_task_queue task_queue; + bool async_mode; // Rope Cache - ggml_cann_rope_cache rope_cache; + ggml_cann_rope_cache rope_cache; // Constant Pool ggml_cann_tensor_cache rms_norm_one_tensor_cache; ggml_cann_tensor_cache rms_norm_zero_tensor_cache; - aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ + aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */ /** * @brief Constructor for initializing the context with a given device. * @param device Device ID. */ - explicit ggml_backend_cann_context(int device) - : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) { + explicit ggml_backend_cann_context(int device) : + device(device), + name("CANN" + std::to_string(device)), + task_queue(1024, device) { ggml_cann_set_device(device); description = aclrtGetSocName(); async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); - GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, - device, async_mode ? "ON" : "OFF"); + GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF"); #ifdef USE_ACL_GRAPH acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on")); - GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", - __func__, device, - acl_graph_mode ? "GRAPH" : "EAGER", - acl_graph_mode ? "acl graph enabled" : "acl graph disabled"); + GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER", + acl_graph_mode ? "acl graph enabled" : "acl graph disabled"); #endif } @@ -549,8 +543,7 @@ struct ggml_backend_cann_context { aclrtStream stream() { return stream(0); } // TODO: each stream should have a memory pool. - std::unique_ptr - mem_pool; /**< Memory pool for the device. */ + std::unique_ptr mem_pool; /**< Memory pool for the device. */ /** * @brief Create a new memory pool for a given device. @@ -563,7 +556,7 @@ struct ggml_backend_cann_context { * @brief Get or create the memory pool for the context. * @return Reference to the memory pool. */ - ggml_cann_pool& pool() { + ggml_cann_pool & pool() { if (mem_pool == nullptr) { mem_pool = new_pool_for_device(device); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp old mode 100755 new mode 100644 index ad1adba6..8bd5449f --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -56,14 +56,12 @@ * @param line The line number where the error occurred. * @param msg The error message. */ -[[noreturn]] void ggml_cann_error(const char* stmt, const char* func, - const char* file, int line, const char* msg) { +[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int32_t id = -1; aclrtGetDevice(&id); GGML_LOG_ERROR("CANN error: %s\n", msg); - GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, - file, line); + GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); GGML_LOG_ERROR(" %s\n", stmt); // abort with GGML_ASSERT to get a stack trace GGML_ABORT("CANN error"); @@ -79,7 +77,7 @@ void ggml_cann_set_device(const int32_t device) { aclrtGetDevice(¤t_device); if (device == current_device) { - return; + return; } ACL_CHECK(aclrtSetDevice(device)); } @@ -99,9 +97,11 @@ int32_t ggml_cann_get_device() { * @brief Get the value of the specified environment variable (name). * if not empty, return a std::string object */ -std::optional get_env(const std::string& name) { - const char* val = std::getenv(name.c_str()); - if (!val) return std::nullopt; +std::optional get_env(const std::string & name) { + const char * val = std::getenv(name.c_str()); + if (!val) { + return std::nullopt; + } std::string res = std::string(val); std::transform(res.begin(), res.end(), res.begin(), ::tolower); return res; @@ -110,8 +110,8 @@ std::optional get_env(const std::string& name) { /** * @brief Verify whether the environment variable is a valid value. */ -bool parse_bool(const std::string& value) { - std::unordered_set valid_values = {"on", "1", "yes", "y", "enable", "true"}; +bool parse_bool(const std::string & value) { + std::unordered_set valid_values = { "on", "1", "yes", "y", "enable", "true" }; return valid_values.find(value) != valid_values.end(); } @@ -125,7 +125,7 @@ bool parse_bool(const std::string& value) { * @param value The string to parse. * @return The parsed integer, or 0 if conversion fails. */ -int parse_integer(const std::string& value) { +int parse_integer(const std::string & value) { try { return std::stoi(value); } catch (...) { @@ -144,11 +144,10 @@ int parse_integer(const std::string& value) { static ggml_cann_device_info ggml_cann_init() { ggml_cann_device_info info = {}; - aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count); + aclError err = aclrtGetDeviceCount((uint32_t *) &info.device_count); if (err != ACL_SUCCESS) { - GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", - __func__, aclGetRecentErrMsg()); + GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg()); return info; } @@ -156,16 +155,15 @@ static ggml_cann_device_info ggml_cann_init() { for (int id = 0; id < info.device_count; ++id) { aclrtPhysicalMemProp prop = {}; - prop.handleType = ACL_MEM_HANDLE_TYPE_NONE; - prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; - prop.memAttr = ACL_HBM_MEM_HUGE; - prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = id; - prop.reserve = 0; - err = aclrtMemGetAllocationGranularity( - &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, - &info.devices[id].vmm_granularity); - info.devices[id].vmm = err == ACL_SUCCESS; + prop.handleType = ACL_MEM_HANDLE_TYPE_NONE; + prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; + prop.memAttr = ACL_HBM_MEM_HUGE; + prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = id; + prop.reserve = 0; + err = aclrtMemGetAllocationGranularity(&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, + &info.devices[id].vmm_granularity); + info.devices[id].vmm = err == ACL_SUCCESS; size_t free, total; ggml_backend_cann_get_device_memory(id, &free, &total); @@ -185,7 +183,7 @@ static ggml_cann_device_info ggml_cann_init() { * * @return A reference to the structure containing the device information. */ -const ggml_cann_device_info& ggml_cann_info() { +const ggml_cann_device_info & ggml_cann_info() { static ggml_cann_device_info info = ggml_cann_init(); return info; } @@ -205,7 +203,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { /** * @brief The minimum free margin for a buffer. */ - static const size_t min_free_margin = 1ull << 20; // 1MB + static const size_t min_free_margin = 1ull << 20; // 1MB /** * @brief The alignment for buffer allocation. @@ -226,22 +224,18 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * @brief Structure representing a CANN buffer. */ struct ggml_cann_buffer { - void* ptr = nullptr; ///< Pointer to the buffer. - size_t size = 0; ///< Size of the buffer. - std::chrono::steady_clock::time_point last_used; ///< Last used time. + void * ptr = nullptr; ///< Pointer to the buffer. + size_t size = 0; ///< Size of the buffer. + std::chrono::steady_clock::time_point last_used; ///< Last used time. - bool operator>(const ggml_cann_buffer& other) const { - return size > other.size; - } + bool operator>(const ggml_cann_buffer & other) const { return size > other.size; } }; /** * @brief Array of CANN buffers in the pool. */ - std::unordered_map buffer_pool; - std::priority_queue, - std::greater<>> free_buffers ; + std::unordered_map buffer_pool; + std::priority_queue, std::greater<>> free_buffers; /** * @brief Total size of all buffers in the pool. @@ -262,7 +256,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { */ ~ggml_cann_pool_buf_prio() { ggml_cann_set_device(device); - for (auto& [b_ptr, b_size] : buffer_pool) { + for (auto & [b_ptr, b_size] : buffer_pool) { aclrtFree(b_ptr); pool_size -= b_size; } @@ -278,14 +272,14 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * the allocated buffer. * @return A pointer to the allocated buffer. */ - void* alloc(size_t size, size_t* actual_size) override { + void * alloc(size_t size, size_t * actual_size) override { size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } - void* ptr = nullptr; - auto now = std::chrono::steady_clock::now(); + void * ptr = nullptr; + auto now = std::chrono::steady_clock::now(); std::vector free_buffers_rest; free_buffers_rest.reserve(free_buffers.size()); @@ -298,24 +292,22 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { const size_t margin = b.size - size; if (margin <= max_reuse_margin) { *actual_size = b.size; - ptr = b.ptr; + ptr = b.ptr; #ifdef DEBUG_CANN_MALLOC GGML_LOG_INFO( "cann pool[%d]: reused %p, " "pool_size = %5u MB, " "size = %5u MB, " "margin = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); + device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(margin, 1048576) / 1048576)); #endif break; } } - bool should_clean = !disable_clean && - b.size > min_free_margin && + bool should_clean = !disable_clean && b.size > min_free_margin && std::chrono::duration_cast(now - b.last_used).count() > 100; if (should_clean) { // free the buffer if the size is needed to be freed @@ -327,20 +319,20 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { "cann pool[%d]: clean %p, " "pool_size = %5u MB, " "size = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); + device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576)); #endif continue; } free_buffers_rest.push_back(b); } - for (ggml_cann_buffer &b : free_buffers_rest) { + for (ggml_cann_buffer & b : free_buffers_rest) { free_buffers.push(std::move(b)); } #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); + GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, + (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576)); #endif if (ptr != nullptr) { return ptr; @@ -356,8 +348,8 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { "cann pool[%d]: allocate %p, " "pool_size = %5u MB, " "size = %5u MB\n", - device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(size, 1048576) / 1048576)); + device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(size, 1048576) / 1048576)); #endif buffer_pool.emplace(ptr, size); return ptr; @@ -369,7 +361,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * @param ptr Pointer to the buffer to free. * @param size Size of the buffer to free. */ - void free(void* ptr, size_t size) override { + void free(void * ptr, size_t size) override { GGML_UNUSED(size); auto it = buffer_pool.find(ptr); if (it == buffer_pool.end()) { @@ -377,13 +369,12 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { } auto now = std::chrono::steady_clock::now(); - free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now}); + free_buffers.emplace(ggml_cann_buffer{ ptr, it->second, now }); #ifdef DEBUG_CANN_MALLOC GGML_LOG_INFO( "cann pool[%d]: return %p, " "pool_size = %5u MB\n", - device, ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); + device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576)); #endif } }; @@ -402,7 +393,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { /** * @brief The minimum free margin for a buffer. */ - static const size_t min_free_margin = 1ull << 20; // 1MB + static const size_t min_free_margin = 1ull << 20; // 1MB /** * @brief The alignment for buffer allocation. @@ -428,10 +419,10 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @brief Structure representing a CANN buffer. */ struct ggml_cann_buffer { - void* ptr = nullptr; ///< Pointer to the buffer memory. - size_t size = 0; ///< Size of the buffer. - bool used = false; ///< Whether the buffer is currently in use. - std::chrono::steady_clock::time_point last_used; ///< Last used time. + void * ptr = nullptr; ///< Pointer to the buffer memory. + size_t size = 0; ///< Size of the buffer. + bool used = false; ///< Whether the buffer is currently in use. + std::chrono::steady_clock::time_point last_used; ///< Last used time. }; /** @@ -459,7 +450,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { ~ggml_cann_pool_buf() { ggml_cann_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { - ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { aclrtFree(b.ptr); pool_size -= b.size; @@ -476,18 +467,18 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * the allocated buffer. * @return A pointer to the allocated buffer. */ - void* alloc(size_t size, size_t* actual_size) override { + void * alloc(size_t size, size_t * actual_size) override { size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } - void* ptr = nullptr; - auto now = std::chrono::steady_clock::now(); + void * ptr = nullptr; + auto now = std::chrono::steady_clock::now(); int i = 0; for (; i < MAX_BUFFERS; ++i) { - ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_buffer & b = buffer_pool[i]; if (b.ptr == nullptr) { break; } @@ -499,25 +490,23 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { const size_t margin = b.size - size; if (margin <= max_reuse_margin) { *actual_size = b.size; - b.used = true; - ptr = b.ptr; + b.used = true; + ptr = b.ptr; #ifdef DEBUG_CANN_MALLOC GGML_LOG_INFO( "cann pool[%d]: reused %p, " "pool_size = %5u MB, " "size = %5u MB, " "margin = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); + device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(margin, 1048576) / 1048576)); #endif break; } } - bool should_clean = !disable_clean && - b.size > min_free_margin && + bool should_clean = !disable_clean && b.size > min_free_margin && std::chrono::duration_cast(now - b.last_used).count() > 100; if (should_clean) { // free the buffer if the size is needed to be freed @@ -528,9 +517,8 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { "cann pool[%d]: clean %p, " "pool_size = %5u MB, " "size = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); + device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576)); #endif b.ptr = nullptr; } @@ -541,13 +529,13 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { if (i < MAX_BUFFERS) { // allocate a new buffer if no buffer can be reused - ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_buffer & b = buffer_pool[i]; ggml_cann_set_device(device); ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); pool_size += size; *actual_size = size; - b.size = size; - b.used = true; + b.size = size; + b.used = true; if (i >= MAX_BUFFERS - 8) { GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device); } @@ -556,9 +544,8 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { "cann pool[%d]: allocate %p, " "pool_size = %5u MB, " "size = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); + device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576)); #endif return b.ptr; } @@ -572,21 +559,20 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @param ptr Pointer to the buffer to free. * @param size Size of the buffer to free. */ - void free(void* ptr, size_t size) override { + void free(void * ptr, size_t size) override { GGML_UNUSED(size); for (int i = 0; i < MAX_BUFFERS; ++i) { - ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_buffer & b = buffer_pool[i]; if (b.ptr != ptr) { continue; } - b.used = false; + b.used = false; b.last_used = std::chrono::steady_clock::now(); #ifdef DEBUG_CANN_MALLOC GGML_LOG_INFO( "cann pool[%d]: return %p, " "pool_size = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); + device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576)); #endif return; } @@ -614,7 +600,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { /** * @brief Pointer to the start of the virtual memory pool. */ - void* pool_addr = 0; + void * pool_addr = 0; /** * @brief Amount of virtual memory used in the pool. @@ -639,7 +625,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { /** * @brief Offsets for the mapped memory regions. */ - std::vector map_offsets; + std::vector map_offsets; /** * @brief Constructor to initialize the buffer pool with virtual memory for @@ -647,11 +633,10 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * * @param device The device ID to associate with this buffer pool. */ - explicit ggml_cann_pool_vmm(int device) - : device(device) { - auto dev = ggml_cann_info().devices[device]; + explicit ggml_cann_pool_vmm(int device) : device(device) { + auto dev = ggml_cann_info().devices[device]; granularity = dev.vmm_granularity; - max_size = dev.total_vram; + max_size = dev.total_vram; } /** @@ -659,10 +644,10 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ ~ggml_cann_pool_vmm() { if (pool_addr != 0) { - for (auto& offset : map_offsets) { + for (auto & offset : map_offsets) { ACL_CHECK(aclrtUnmapMem(offset)); } - for (auto& handle : handles) { + for (auto & handle : handles) { ACL_CHECK(aclrtFreePhysical(handle)); } ACL_CHECK(aclrtReleaseMemAddress(pool_addr)); @@ -677,11 +662,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * the allocated buffer. * @return A pointer to the allocated buffer. */ - void* alloc(size_t size, size_t* actual_size) override { + void * alloc(size_t size, size_t * actual_size) override { // round up the allocation size to the alignment to ensure that all // allocations are aligned for all data types const size_t alignment = 128; - size = GGML_PAD(size, alignment); + size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } @@ -691,53 +676,51 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { if (size > avail) { // round up to the next multiple of the granularity size_t reserve_size = size - avail; - reserve_size = GGML_PAD(reserve_size, granularity); + reserve_size = GGML_PAD(reserve_size, granularity); GGML_ASSERT(pool_size + reserve_size <= max_size); // allocate more physical memory aclrtPhysicalMemProp prop = {}; - prop.handleType = ACL_MEM_HANDLE_TYPE_NONE; - prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; - prop.memAttr = ACL_HBM_MEM_HUGE; - prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = device; - prop.reserve = 0; + prop.handleType = ACL_MEM_HANDLE_TYPE_NONE; + prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; + prop.memAttr = ACL_HBM_MEM_HUGE; + prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.reserve = 0; aclrtDrvMemHandle handle; ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0)); // reserve virtual address space (if not already reserved) if (pool_addr == 0) { - ACL_CHECK(aclrtReserveMemAddress( - &pool_addr, max_size, 0, NULL, 1)); + ACL_CHECK(aclrtReserveMemAddress(&pool_addr, max_size, 0, NULL, 1)); } // map at the end of the pool - ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0, - handle, 0)); + ACL_CHECK(aclrtMapMem((char *) pool_addr + pool_size, reserve_size, 0, handle, 0)); handles.push_back(handle); - map_offsets.push_back((char*)pool_addr + pool_size); + map_offsets.push_back((char *) pool_addr + pool_size); // add to the pool pool_size += reserve_size; #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", - device, (unsigned long long) (pool_size/1024/1024), - (unsigned long long) (reserve_size/1024/1024)); + GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", device, + (unsigned long long) (pool_size / 1024 / 1024), + (unsigned long long) (reserve_size / 1024 / 1024)); #endif } GGML_ASSERT(pool_addr != 0); - void* ptr = (void*)((char*)pool_addr + pool_used); + void * ptr = (void *) ((char *) pool_addr + pool_used); *actual_size = size; pool_used += size; #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, - (unsigned long long)size, (unsigned long long)ptr); + GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, + (unsigned long long) ptr); #endif return ptr; } @@ -748,16 +731,16 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @param ptr Pointer to the buffer to free. * @param size Size of the buffer to free. */ - void free(void* ptr, size_t size) override { + void free(void * ptr, size_t size) override { #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, - (unsigned long long)size, (unsigned long long)ptr); + GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, + (unsigned long long) ptr); #endif pool_used -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used)); + GGML_ASSERT(ptr == (void *) ((char *) pool_addr + pool_used)); } }; @@ -769,8 +752,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @param device The device ID for which to create the pool. * @return A unique pointer to the created CANN pool. */ -std::unique_ptr ggml_backend_cann_context::new_pool_for_device( - int device) { +std::unique_ptr ggml_backend_cann_context::new_pool_for_device(int device) { std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); if (mem_pool_type == "prio") { @@ -795,9 +777,8 @@ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID. */ struct ggml_backend_cann_buffer_context { - int32_t device; ///< The device ID associated with this buffer context. - void* dev_ptr = - nullptr; ///< Pointer to the device memory allocated for the buffer. + int32_t device; ///< The device ID associated with this buffer context. + void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer. /** * @brief Constructor to initialize the CANN buffer context. @@ -805,9 +786,7 @@ struct ggml_backend_cann_buffer_context { * @param device The device ID associated with this buffer context. * @param dev_ptr Pointer to the device memory allocated for the buffer. */ - ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr) - : device(device), - dev_ptr(dev_ptr) {} + ggml_backend_cann_buffer_context(int32_t device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {} /** * @brief Destructor to free the device memory allocated for the buffer. @@ -825,8 +804,8 @@ struct ggml_backend_cann_buffer_context { * @return true if the buffer is a CANN buffer, false otherwise. */ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft); -static bool ggml_backend_buffer_is_cann( - ggml_backend_buffer_t buffer) { + +static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { return ggml_backend_buft_is_cann(buffer->buft); } @@ -838,10 +817,8 @@ static bool ggml_backend_buffer_is_cann( * * @param buffer The CANN buffer to free. */ -static void ggml_backend_cann_buffer_free_buffer( - ggml_backend_buffer_t buffer) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; +static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; delete ctx; } @@ -854,10 +831,8 @@ static void ggml_backend_cann_buffer_free_buffer( * @param buffer The CANN buffer whose base pointer is to be retrieved. * @return A pointer to the base of the device memory allocated for the buffer. */ -static void* ggml_backend_cann_buffer_get_base( - ggml_backend_buffer_t buffer) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; +static void * ggml_backend_cann_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; return ctx->dev_ptr; } @@ -874,21 +849,17 @@ static void* ggml_backend_cann_buffer_get_base( * @param dst Pointer to the destination buffer where transformed data will be * stored. */ -static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, - const void* src, - void* dst) { +static void ggml_backend_cann_transform_q4_0(ggml_tensor * tensor, const void * src, void * dst) { + int64_t n_elems = ggml_nelements(tensor); + int64_t groups = n_elems / QK4_0; + size_t quant_bytes = n_elems * sizeof(uint8_t) / 2; - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK4_0; - size_t quant_bytes = n_elems * sizeof(uint8_t) / 2; - - uint8_t* quant_offset = (uint8_t*)dst; - uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes); + uint8_t * quant_offset = (uint8_t *) dst; + uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes); for (int i = 0; i < groups; i++) { - const block_q4_0* group = - (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0)); - *scale_offset = group->d; + const block_q4_0 * group = (const block_q4_0 *) ((const char *) src + i * sizeof(block_q4_0)); + *scale_offset = group->d; scale_offset++; // 0-15 @@ -907,8 +878,7 @@ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, } // put (uint4b_t -8) into int4b_t - for (quant_offset = (uint8_t*)dst; - quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) { + for (quant_offset = (uint8_t *) dst; quant_offset < (uint8_t *) dst + quant_bytes; quant_offset++) { (*quant_offset) ^= 0x88; } } @@ -926,29 +896,27 @@ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, * @param dst Pointer to the destination buffer where the Q4.0 formatted data * will be stored. */ -static void ggml_backend_cann_transform_back_q4_0( - const ggml_tensor* tensor, void* src, void* dst) { +static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor * tensor, void * src, void * dst) { + int64_t n_elems = ggml_nelements(tensor); + int64_t groups = n_elems / QK4_0; + size_t quant_bytes = n_elems * sizeof(uint8_t) / 2; - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK4_0; - size_t quant_bytes = n_elems * sizeof(uint8_t) / 2; + uint8_t * quant_offset = (uint8_t *) src; + uint16_t * scale_offset = (uint16_t *) ((char *) src + quant_bytes); - uint8_t* quant_offset = (uint8_t*)src; - uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes); - - for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) { + for (; quant_offset < (uint8_t *) src + quant_bytes; quant_offset++) { (*quant_offset) ^= 0x88; } - quant_offset = (uint8_t*)src; + quant_offset = (uint8_t *) src; for (int i = 0; i < groups; i++) { - block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0)); - group->d = *scale_offset; + block_q4_0 * group = (block_q4_0 *) ((char *) dst + i * sizeof(block_q4_0)); + group->d = *scale_offset; scale_offset++; // 0-15 for (int j = 0; j < QK4_0 / 2; j += 2) { - group->qs[j] = ((*quant_offset) & 0x0F); + group->qs[j] = ((*quant_offset) & 0x0F); group->qs[j + 1] = ((*quant_offset) >> 4); quant_offset++; } @@ -975,20 +943,17 @@ static void ggml_backend_cann_transform_back_q4_0( * @param dst Pointer to the destination buffer where transformed data will be * stored. */ -static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, - const void* src, - void* dst) { - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK8_0; - size_t quant_bytes = n_elems * sizeof(uint8_t); +static void ggml_backend_cann_transform_q8_0(ggml_tensor * tensor, const void * src, void * dst) { + int64_t n_elems = ggml_nelements(tensor); + int64_t groups = n_elems / QK8_0; + size_t quant_bytes = n_elems * sizeof(uint8_t); - uint8_t* quant_offset = (uint8_t*)dst; - uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes); + uint8_t * quant_offset = (uint8_t *) dst; + uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes); for (int i = 0; i < groups; i++) { - const block_q8_0* group = - (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0)); - *scale_offset = group->d; + const block_q8_0 * group = (const block_q8_0 *) ((const char *) src + i * sizeof(block_q8_0)); + *scale_offset = group->d; scale_offset++; size_t group_quant_size = QK8_0 * sizeof(uint8_t); memcpy(quant_offset, group->qs, group_quant_size); @@ -1009,19 +974,17 @@ static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, * @param dst Pointer to the destination buffer where the Q8.0 formatted data * will be stored. */ -static void ggml_backend_cann_transform_back_q8_0( - const ggml_tensor* tensor, const void* src, void* dst) { - int64_t n_elems = ggml_nelements(tensor); - int64_t groups = n_elems / QK8_0; - size_t quant_bytes = n_elems * sizeof(uint8_t); +static void ggml_backend_cann_transform_back_q8_0(const ggml_tensor * tensor, const void * src, void * dst) { + int64_t n_elems = ggml_nelements(tensor); + int64_t groups = n_elems / QK8_0; + size_t quant_bytes = n_elems * sizeof(uint8_t); - const uint8_t* quant_offset = (const uint8_t*)src; - const uint16_t* scale_offset = - (const uint16_t*)((const char*)src + quant_bytes); + const uint8_t * quant_offset = (const uint8_t *) src; + const uint16_t * scale_offset = (const uint16_t *) ((const char *) src + quant_bytes); for (int i = 0; i < groups; i++) { - block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0)); - group->d = *scale_offset; + block_q8_0 * group = (block_q8_0 *) ((char *) dst + i * sizeof(block_q8_0)); + group->d = *scale_offset; scale_offset++; size_t group_quant_size = QK8_0 * sizeof(uint8_t); memcpy(group->qs, quant_offset, group_quant_size); @@ -1041,8 +1004,7 @@ static void ggml_backend_cann_transform_back_q8_0( * @param dst Pointer to the destination buffer where transformed data will be * stored. */ -static void ggml_backend_cann_transform(ggml_tensor* tensor, - const void* src, void* dst) { +static void ggml_backend_cann_transform(ggml_tensor * tensor, const void * src, void * dst) { switch (tensor->type) { case GGML_TYPE_Q4_0: ggml_backend_cann_transform_q4_0(tensor, src, dst); @@ -1067,8 +1029,7 @@ static void ggml_backend_cann_transform(ggml_tensor* tensor, * @param dst Pointer to the destination buffer where transformed tensor data * will be stored. */ -static void ggml_backend_cann_transform_back( - const ggml_tensor* tensor, void* src, void* dst) { +static void ggml_backend_cann_transform_back(const ggml_tensor * tensor, void * src, void * dst) { switch (tensor->type) { case GGML_TYPE_Q4_0: ggml_backend_cann_transform_back_q4_0(tensor, src, dst); @@ -1109,8 +1070,7 @@ static bool need_transform(ggml_type type) { * @param buffer The CANN buffer from which to initialize the tensor. * @param tensor Pointer to the tensor to be initialized. */ -static enum ggml_status ggml_backend_cann_buffer_init_tensor( - ggml_backend_buffer_t buffer, ggml_tensor* tensor) { +static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { if (tensor->view_src != NULL && tensor->view_offs == 0) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; @@ -1121,13 +1081,11 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor( if (ggml_is_quantized(tensor->type)) { // Initialize padding to 0 to avoid possible NaN values size_t original_size = ggml_nbytes(tensor); - size_t padded_size = - ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); if (padded_size > original_size && tensor->view_src == nullptr) { size_t memset_size = padded_size - original_size; - ACL_CHECK(aclrtMemset((char*)tensor->data + original_size, - memset_size, 0, memset_size)); + ACL_CHECK(aclrtMemset((char *) tensor->data + original_size, memset_size, 0, memset_size)); } } return GGML_STATUS_SUCCESS; @@ -1141,8 +1099,8 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor( * designed to be used with a global array, one per device. */ struct ggml_cann_nz_workspace { - void* ptr; // Pointer to allocated device buffer - size_t allocated; // Size of currently allocated buffer in bytes + void * ptr; // Pointer to allocated device buffer + size_t allocated; // Size of currently allocated buffer in bytes /** * @brief Constructor. Initializes the workspace with no allocated memory. @@ -1158,7 +1116,7 @@ struct ggml_cann_nz_workspace { void clear() { if (ptr) { ACL_CHECK(aclrtFree(ptr)); - ptr = nullptr; + ptr = nullptr; allocated = 0; } } @@ -1185,7 +1143,7 @@ struct ggml_cann_nz_workspace { * * @return Pointer to the allocated buffer, or nullptr if not allocated. */ - void* get() const { return ptr; } + void * get() const { return ptr; } }; /** @@ -1207,19 +1165,17 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES]; * @note The workspace buffer used in this function is managed globally and reused * across calls. This reduces overhead from repeated memory allocation and deallocation. */ -static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) { - aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, - tensor->nb, 2, ACL_FORMAT_ND, offset); - uint64_t workspaceSize = 0; - aclOpExecutor *executor; +static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) { + aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset); + uint64_t workspaceSize = 0; + aclOpExecutor * executor; // TransMatmulWeight - ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, - &workspaceSize, &executor)); + ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor)); // Avoid frequent malloc/free of the workspace. g_nz_workspaces[device].realloc(workspaceSize); - void* g_nz_workspace = g_nz_workspaces[device].get(); + void * g_nz_workspace = g_nz_workspaces[device].get(); ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr)); ACL_CHECK(aclDestroyTensor(weightTransposed)); @@ -1238,11 +1194,12 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) * @param offset Offset in the source data from where to start copying. * @param size Size of the data to be copied, in bytes. */ -static void ggml_backend_cann_buffer_set_tensor( - ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, - size_t offset, size_t size) { - ggml_backend_cann_buffer_context *ctx = - (ggml_backend_cann_buffer_context *)buffer->context; +static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); // TODO: refer to cann(#6017), it use thread's default stream. @@ -1252,20 +1209,17 @@ static void ggml_backend_cann_buffer_set_tensor( // Only check env once. static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size, - ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) { + ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); weight_format_to_nz(tensor, offset, ctx->device); } } else { - void *transform_buffer = malloc(size); + void * transform_buffer = malloc(size); ggml_backend_cann_transform(tensor, data, transform_buffer); - ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, - transform_buffer, size, - ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); free(transform_buffer); } } @@ -1283,22 +1237,20 @@ static void ggml_backend_cann_buffer_set_tensor( * @param offset Offset in the destination buffer where to start copying. * @param size Size of the data to be copied, in bytes. */ -static void ggml_backend_cann_buffer_get_tensor( - ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data, - size_t offset, size_t size) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; +static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST)); } else { - void* transform_buffer = malloc(size); - ACL_CHECK(aclrtMemcpy(transform_buffer, size, - (char*)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST)); + void * transform_buffer = malloc(size); + ACL_CHECK(aclrtMemcpy(transform_buffer, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST)); ggml_backend_cann_transform_back(tensor, transform_buffer, data); free(transform_buffer); } @@ -1317,19 +1269,17 @@ static void ggml_backend_cann_buffer_get_tensor( * @param dst Pointer to the destination tensor where the data will be copied. * @return true if the copy operation succeeded, false otherwise. */ -static bool ggml_backend_cann_buffer_cpy_tensor( - ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) { +static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { if (ggml_backend_buffer_is_cann(src->buffer)) { - ggml_backend_cann_buffer_context* src_ctx = - (ggml_backend_cann_buffer_context*)src->buffer->context; - ggml_backend_cann_buffer_context* dst_ctx = - (ggml_backend_cann_buffer_context*)buffer->context; + ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context; + ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context; size_t memcpy_size = ggml_nbytes(src); // Same device. if (src_ctx->device == dst_ctx->device) { - ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size, - (const char*)src->data, memcpy_size, + ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE)); return true; } else { @@ -1339,13 +1289,11 @@ static bool ggml_backend_cann_buffer_cpy_tensor( #endif // Different device but can access by peer. int32_t canAccessPeer = 0; - ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, - dst_ctx->device)); + ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, dst_ctx->device)); if (canAccessPeer) { ggml_cann_set_device(src_ctx->device); ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0)); - ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size, - (const char*)src->data, memcpy_size, + ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE)); return true; } @@ -1363,10 +1311,8 @@ static bool ggml_backend_cann_buffer_cpy_tensor( * @param buffer The CANN buffer to be cleared. * @param value The value to which each byte in the buffer will be set. */ -static void ggml_backend_cann_buffer_clear( - ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; +static void ggml_backend_cann_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size)); @@ -1396,9 +1342,8 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { * buffer type. */ struct ggml_backend_cann_buffer_type_context { - int32_t - device; /**< Device identifier associated with the buffer context. */ - std::string name; /**< Name associated with the buffer context. */ + int32_t device; /**< Device identifier associated with the buffer context. */ + std::string name; /**< Name associated with the buffer context. */ }; /** @@ -1410,10 +1355,8 @@ struct ggml_backend_cann_buffer_type_context { * @param buft Pointer to the buffer type context. * @return Const pointer to the C-style string containing the name. */ -static const char* ggml_backend_cann_buffer_type_name( - ggml_backend_buffer_type_t buft) { - ggml_backend_cann_buffer_type_context* buft_ctx = - (ggml_backend_cann_buffer_type_context*)buft->context; +static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; return buft_ctx->name.c_str(); } @@ -1428,34 +1371,27 @@ static const char* ggml_backend_cann_buffer_type_name( * @param size Size in bytes of the buffer to allocate. * @return Pointer to the allocated buffer, or nullptr if allocation fails. */ -static ggml_backend_buffer_t -ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, - size_t size) { - ggml_backend_cann_buffer_type_context* buft_ctx = - (ggml_backend_cann_buffer_type_context*)buft->context; +static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; ggml_cann_set_device(buft_ctx->device); const size_t alignment = 128; - size = GGML_PAD(size, alignment); + size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } - void* dev_ptr; + void * dev_ptr; aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); if (err != ACL_SUCCESS) { - GGML_LOG_ERROR( - "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", - __func__, size / 1024.0 / 1024.0, buft_ctx->device, - aclGetRecentErrMsg()); + GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", __func__, + size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg()); return nullptr; } - ggml_backend_cann_buffer_context* ctx = - new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr); + ggml_backend_cann_buffer_context * ctx = new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr); - return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, - ctx, size); + return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size); } /** @@ -1470,8 +1406,7 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, * @return The alignment requirement in bytes (fixed at 128 bytes for CANN * buffers). */ -static size_t ggml_backend_cann_buffer_type_get_alignment( - ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_cann_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return 128; GGML_UNUSED(buft); @@ -1491,10 +1426,10 @@ static size_t ggml_backend_cann_buffer_type_get_alignment( * @return The total allocation size in bytes required for the tensor in the * CANN buffer. */ -static size_t ggml_backend_cann_buffer_type_get_alloc_size( - ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) { - size_t size = ggml_nbytes(tensor); - int64_t ne0 = tensor->ne[0]; +static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + size_t size = ggml_nbytes(tensor); + int64_t ne0 = tensor->ne[0]; // Only check env once. static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); @@ -1507,19 +1442,17 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size( // size += (line_size_align_32 - line_size); if (ggml_is_quantized(tensor->type)) { if (ne0 % MATRIX_ROW_PADDING != 0) { - size += ggml_row_size( - tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - } else if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) { + } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { // NZ format weight are not support quantized yet. // If ND tensor transform to NZ, size may changed. - int64_t shape[] = {tensor->ne[1], tensor->ne[0]}; + int64_t shape[] = { tensor->ne[1], tensor->ne[0] }; GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - const aclIntArray *acl_shape = aclCreateIntArray(shape, 2); - size_t new_size; - ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, - ggml_cann_type_mapping(tensor->type), &new_size)); + const aclIntArray * acl_shape = aclCreateIntArray(shape, 2); + size_t new_size; + ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, ggml_cann_type_mapping(tensor->type), &new_size)); ACL_CHECK(aclDestroyIntArray(acl_shape)); size = std::max(size, new_size); } @@ -1560,17 +1493,15 @@ static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface * @return A pointer to the buffer type interface for the specified device, or * nullptr if the device index is out of range. */ -ggml_backend_buffer_type_t -ggml_backend_cann_buffer_type(int32_t device) { - static std::mutex mutex; +ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) { + static std::mutex mutex; std::lock_guard lock(mutex); if (device >= ggml_backend_cann_get_device_count()) { return nullptr; } - static ggml_backend_buffer_type - ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES]; + static ggml_backend_buffer_type ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES]; static bool ggml_backend_cann_buffer_type_initialized = false; @@ -1580,8 +1511,7 @@ ggml_backend_cann_buffer_type(int32_t device) { /* .iface = */ ggml_backend_cann_buffer_type_interface, /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i), /* .context = */ - new ggml_backend_cann_buffer_type_context{ - i, "CANN" + std::to_string(i)}, + new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i) }, }; } ggml_backend_cann_buffer_type_initialized = true; @@ -1645,16 +1575,16 @@ static void * ggml_cann_host_malloc(size_t size) { } const size_t alignment = 128; - size = GGML_PAD(size, alignment); + size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } - void * hostPtr = nullptr; - aclError err = aclrtMallocHost((void **) &hostPtr, size); + void * hostPtr = nullptr; + aclError err = aclrtMallocHost((void **) &hostPtr, size); if (err != ACL_SUCCESS) { - GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, - size / 1024.0 / 1024.0, aclGetRecentErrMsg()); + GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, + aclGetRecentErrMsg()); return nullptr; } return hostPtr; @@ -1667,7 +1597,8 @@ static void * ggml_cann_host_malloc(size_t size) { * @param size Size in bytes of the host buffer to allocate. * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails. */ -static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { void * hostPtr = ggml_cann_host_malloc(size); if (hostPtr == nullptr) { @@ -1676,8 +1607,8 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm } ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size); - buffer->buft = buft; - buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free; + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free; return buffer; } @@ -1691,14 +1622,15 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() { static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = { /* .iface = */ { - /* .get_name = */ ggml_backend_cann_host_buffer_type_name, - /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_name = */ ggml_backend_cann_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, - /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, - }, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0), + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ + ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0), /* .context = */ nullptr, }; @@ -1718,8 +1650,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() { * stored. * @return true if the computation was successful; false otherwise. */ -static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, - struct ggml_tensor* dst) { +static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) { switch (dst->op) { case GGML_OP_REPEAT: ggml_cann_repeat(ctx, dst); @@ -1765,14 +1696,14 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_UNARY_OP_SILU: GGML_CANN_CALL_OP_UNARY(Silu); break; - case GGML_UNARY_OP_GELU_QUICK: { - auto lambda = [](ggml_backend_cann_context& ctx, - aclTensor* acl_src, - aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); - }; - ggml_cann_op_unary(lambda, ctx, dst); - } break; + case GGML_UNARY_OP_GELU_QUICK: + { + auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary(lambda, ctx, dst); + } + break; case GGML_UNARY_OP_TANH: GGML_CANN_CALL_OP_UNARY(Tanh); break; @@ -1817,14 +1748,14 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_GLU_OP_SWIGLU: GGML_CANN_CALL_OP_UNARY_GATED(Silu); break; - case GGML_GLU_OP_GEGLU_QUICK: { - auto lambda = [](ggml_backend_cann_context& ctx, - aclTensor* acl_src, - aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); - }; - ggml_cann_op_unary_gated(lambda, ctx, dst); - } break; + case GGML_GLU_OP_GEGLU_QUICK: + { + auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary_gated(lambda, ctx, dst); + } + break; default: return false; } @@ -1956,9 +1887,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, * @param backend Pointer to the CANN backend structure. * @return A pointer to a constant string representing the backend name. */ -static const char* ggml_backend_cann_name(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; +static const char * ggml_backend_cann_name(ggml_backend_t backend) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; return cann_ctx->name.c_str(); } @@ -1972,8 +1902,7 @@ static const char* ggml_backend_cann_name(ggml_backend_t backend) { * @param backend Pointer to the CANN backend structure to be freed. */ static void ggml_backend_cann_free(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; ACL_CHECK(aclrtSynchronizeDevice()); ACL_CHECK(aclrtResetDevice(cann_ctx->device)); @@ -1981,7 +1910,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) { delete backend; } - /** * @brief Sets tensor data asynchronously in the CANN backend. * @@ -1994,21 +1922,17 @@ static void ggml_backend_cann_free(ggml_backend_t backend) { * @param size Size of the data to copy in bytes. */ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, - ggml_tensor *tensor, - const void *data, - size_t offset, - size_t size) { - ggml_backend_cann_context *cann_ctx = - (ggml_backend_cann_context *)backend->context; - ggml_backend_buffer_t buf = - tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && - "unsupported buffer type"); + GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type"); GGML_ASSERT(!ggml_is_quantized(tensor->type)); - ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size, - ACL_MEMCPY_HOST_TO_DEVICE); + ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE); } /** @@ -2022,21 +1946,18 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, * @param offset Offset in bytes within the host data. * @param size Size of the data to copy in bytes. */ -static void ggml_backend_cann_get_tensor_async( - ggml_backend_t backend, const ggml_tensor *tensor, void *data, - size_t offset, size_t size) { - ggml_backend_cann_context *cann_ctx = - (ggml_backend_cann_context *)backend->context; - ggml_backend_buffer_t buf = - tensor->view_src ? tensor->view_src->buffer : tensor->buffer; +static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && - "unsupported buffer type"); + GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type"); GGML_ASSERT(!ggml_is_quantized(tensor->type)); - ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST); - + ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST); } /** @@ -2052,28 +1973,23 @@ static void ggml_backend_cann_get_tensor_async( * @param dst Pointer to the destination tensor to copy data to. * @return true if the copy operation succeeds, false otherwise. */ -static bool ggml_backend_cann_cpy_tensor_async( - ggml_backend_t backend_src, ggml_backend_t backend_dst, - const ggml_tensor* src, ggml_tensor* dst) { - GGML_ASSERT(ggml_backend_is_cann(backend_src) || - ggml_backend_is_cann(backend_dst)); +static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src, + ggml_backend_t backend_dst, + const ggml_tensor * src, + ggml_tensor * dst) { + GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst)); - GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src)); + GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src)); - if (!ggml_backend_buffer_is_cann(src->buffer) || - !ggml_backend_buffer_is_cann(dst->buffer)) { + if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) { return false; } - ggml_backend_buffer_t buf_src = - src->view_src ? src->view_src->buffer : src->buffer; - ggml_backend_buffer_t buf_dst = - dst->view_src ? dst->view_src->buffer : dst->buffer; + ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; + ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - ggml_backend_cann_context* cann_ctx_src = - (ggml_backend_cann_context*)backend_src->context; - ggml_backend_cann_context* cann_ctx_dst = - (ggml_backend_cann_context*)backend_dst->context; + ggml_backend_cann_context * cann_ctx_src = (ggml_backend_cann_context *) backend_src->context; + ggml_backend_cann_context * cann_ctx_dst = (ggml_backend_cann_context *) backend_dst->context; size_t copy_size = ggml_nbytes(dst); if (copy_size == 0) { @@ -2084,17 +2000,14 @@ static bool ggml_backend_cann_cpy_tensor_async( // TODO: Support 310p P2P copy return false; #endif - ggml_backend_cann_buffer_context* buf_ctx_src = - (ggml_backend_cann_buffer_context*)buf_src->context; - ggml_backend_cann_buffer_context* buf_ctx_dst = - (ggml_backend_cann_buffer_context*)buf_dst->context; + ggml_backend_cann_buffer_context * buf_ctx_src = (ggml_backend_cann_buffer_context *) buf_src->context; + ggml_backend_cann_buffer_context * buf_ctx_dst = (ggml_backend_cann_buffer_context *) buf_dst->context; GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device); GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device); int32_t canAccessPeer = 0; - ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, - cann_ctx_dst->device)); + ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, cann_ctx_dst->device)); if (!canAccessPeer) { return false; } @@ -2106,8 +2019,7 @@ static bool ggml_backend_cann_cpy_tensor_async( // wait for task_queue empty to keep task order. cann_ctx_src->task_queue.wait(); - ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, + ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_src->stream())); // record event on src stream after the copy // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream @@ -2122,8 +2034,7 @@ static bool ggml_backend_cann_cpy_tensor_async( ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream())); } else { // src and dst are on the same backend - ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, + ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_dst->stream())); } @@ -2139,8 +2050,7 @@ static bool ggml_backend_cann_cpy_tensor_async( * @param backend Pointer to the CANN backend structure to synchronize. */ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; cann_ctx->task_queue.wait(); ggml_cann_set_device(cann_ctx->device); ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); @@ -2168,16 +2078,14 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { * @param cann_ctx The CANN backend context containing the graph cache. * @param cgraph The current ggml computation graph. */ -static void add_lru_matched_graph_node_properties( - ggml_backend_cann_context * cann_ctx, - ggml_cgraph * cgraph) { +static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache). ggml_cann_graph * new_graph = new ggml_cann_graph(); new_graph->ggml_graph_properties.resize(cgraph->n_nodes); for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) { ggml_tensor * node = cgraph->nodes[node_idx]; - auto & prop = new_graph->ggml_graph_properties[node_idx]; + auto & prop = new_graph->ggml_graph_properties[node_idx]; prop.node_address = node->data; prop.node_op = node->op; @@ -2214,11 +2122,9 @@ static void add_lru_matched_graph_node_properties( * @param graph_node_properties The stored properties of a CANN graph node. * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. */ -static bool ggml_graph_node_has_matching_properties( - ggml_tensor * node, - ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_VIEW) { +static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, + ggml_graph_node_properties * graph_node_properties) { + if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) { return false; } @@ -2237,8 +2143,7 @@ static bool ggml_graph_node_has_matching_properties( for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i]) { - if (node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_VIEW) { + if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) { return false; } @@ -2280,8 +2185,8 @@ static bool ggml_graph_node_has_matching_properties( * @return true if a matching cached graph exists; false otherwise. */ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache; - for (auto &graph_ptr : lru_cache.cache_list) { + ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache; + for (auto & graph_ptr : lru_cache.cache_list) { // Skip graphs with a different number of nodes. if (graph_ptr->ggml_graph_properties.size() != static_cast(cgraph->n_nodes)) { continue; @@ -2320,21 +2225,24 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * * @param use_cann_graph Whether to use CANN graph execution. * @param cann_graph_update_required Whether graph capture is needed due to graph changes. */ -static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, - bool & use_cann_graph, bool & cann_graph_update_required) { +static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, + ggml_cgraph * cgraph, + bool & use_cann_graph, + bool & cann_graph_update_required) { #ifdef USE_ACL_GRAPH - ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); + ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); if (use_cann_graph && cann_graph_update_required) { ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); } -#endif // USE_ACL_GRAPH +#endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. if (!use_cann_graph || cann_graph_update_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || + node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } @@ -2347,7 +2255,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx } #ifdef USE_ACL_GRAPH - if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture + if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); } @@ -2355,10 +2263,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx // Execute graph ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream())); } -#endif // USE_ACL_GRAPH +#endif // USE_ACL_GRAPH } - /** * @brief Computes a computational graph using a CANN backend. * @@ -2371,10 +2278,8 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation * completes successfully, otherwise an appropriate error status. */ -static enum ggml_status ggml_backend_cann_graph_compute( - ggml_backend_t backend, ggml_cgraph* cgraph) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; +static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; ggml_cann_set_device(cann_ctx->device); g_nz_workspaces[cann_ctx->device].clear(); @@ -2382,7 +2287,7 @@ static enum ggml_status ggml_backend_cann_graph_compute( cann_ctx->rope_cache.cached = false; #ifdef USE_ACL_GRAPH - bool use_cann_graph = true; + bool use_cann_graph = true; bool cann_graph_update_required = false; static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); @@ -2413,15 +2318,10 @@ static enum ggml_status ggml_backend_cann_graph_compute( } } #else - bool use_cann_graph = false; + bool use_cann_graph = false; bool cann_graph_update_required = false; #endif // USE_ACL_GRAPH - evaluate_and_capture_cann_graph( - cann_ctx, - cgraph, - use_cann_graph, - cann_graph_update_required - ); + evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -2438,8 +2338,7 @@ static enum ggml_status ggml_backend_cann_graph_compute( * @return bool Returns true if the operation is supported by the backend, * otherwise false. */ -static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, - const ggml_tensor* op) { +static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -2474,24 +2373,24 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return false; } break; - case GGML_OP_MUL_MAT: { - switch (op->src[0]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return true; - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: + case GGML_OP_MUL_MAT: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: #ifdef ASCEND_310P - // Q4 && Q8 per group is not support on 310p device - return false; + // Q4 && Q8 per group is not support on 310p device + return false; #endif - // only support contiguous for quantized types. - return ggml_is_contiguous(op->src[0]) && - ggml_is_contiguous(op->src[1]); - default: - return false; + // only support contiguous for quantized types. + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + default: + return false; + } } - } case GGML_OP_MUL_MAT_ID: switch (op->src[0]->type) { case GGML_TYPE_F16: @@ -2504,99 +2403,107 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return false; #endif // only support contiguous for quantized types. - return ggml_is_contiguous(op->src[0]) && - ggml_is_contiguous(op->src[1]); + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); default: return false; } // embedding - case GGML_OP_GET_ROWS: { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q8_0: - return true; - default: + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } + break; + case GGML_OP_SET_ROWS: + { + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + } + break; + case GGML_OP_CPY: + { + ggml_tensor * src = op->src[0]; + if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || + (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) { + // only support F32 and F16. return false; + } + return true; } - } break; - case GGML_OP_SET_ROWS: { - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: + break; + case GGML_OP_CONT: + { + // TODO: support GGML_TYPE_BF16 + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + } + case GGML_OP_ROPE: + { + // TODO: with ops-test v == 1 + // TODO: n_dims <= ne0 + if (op->src[0]->ne[0] != op->op_params[1]) { return false; - } - } break; - case GGML_OP_CPY: { - ggml_tensor *src = op->src[0]; - if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || - (src->type != GGML_TYPE_F32 && - src->type != GGML_TYPE_F16)) { - // only support F32 and F16. - return false; - } - return true; - } break; - case GGML_OP_CONT: { - // TODO: support GGML_TYPE_BF16 - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: - return false; - } - } - case GGML_OP_ROPE: { - // TODO: with ops-test v == 1 - // TODO: n_dims <= ne0 - if (op->src[0]->ne[0] != op->op_params[1]) { - return false; - } + } - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } #ifdef ASCEND_310P - if(!ggml_is_contiguous(op->src[0])){ - return false; - } + if (!ggml_is_contiguous(op->src[0])) { + return false; + } #endif - return true; - } - case GGML_OP_UPSCALE: { - // aclnnUpsampleNearest2dGetWorkspaceSize not support - // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal - if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) { - return false; + return true; } - if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { - return false; + case GGML_OP_UPSCALE: + { + // aclnnUpsampleNearest2dGetWorkspaceSize not support + // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal + if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) { + return false; + } + if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { + return false; + } + return true; } - return true; - } - case GGML_OP_POOL_2D: { - const int32_t * opts = (const int32_t *) op->op_params; + case GGML_OP_POOL_2D: + { + const int32_t * opts = (const int32_t *) op->op_params; #ifdef ASCEND_310P - enum ggml_op_pool opt = static_cast(opts[0]); - if(opt == GGML_OP_POOL_MAX){ - return false; - } + enum ggml_op_pool opt = static_cast(opts[0]); + if (opt == GGML_OP_POOL_MAX) { + return false; + } #endif - const int k0 = opts[1]; - const int k1 = opts[2]; - const int p0 = opts[5]; - const int p1 = opts[6]; - // value of paddingH should be at most half of kernelH - // value of paddingW should be at most half of kernelW - return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); - } + const int k0 = opts[1]; + const int k1 = opts[2]; + const int p0 = opts[5]; + const int p1 = opts[6]; + // value of paddingH should be at most half of kernelH + // value of paddingW should be at most half of kernelW + return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); + } case GGML_OP_DUP: case GGML_OP_SUM: case GGML_OP_IM2COL: @@ -2639,48 +2546,50 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return (op->src[0]->ne[0] - 1) <= 255; case GGML_OP_SCALE: float bias; - memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float)); - return bias == 0.0f; // TODO: support bias != 0.0f + memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float)); + return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support attention sinks [TAG_ATTN_SINKS] if (op->src[2]) { return false; } return true; - case GGML_OP_FLASH_ATTN_EXT:{ + case GGML_OP_FLASH_ATTN_EXT: + { #ifdef ASCEND_310P - // FA not support on 310p device - return false; + // FA not support on 310p device + return false; #endif - // derived from [ggml-cuda.cu] - if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ - return false; + // derived from [ggml-cuda.cu] + if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) { + return false; + } + if (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && + op->src[1]->type != GGML_TYPE_BF16) { + return false; + } + if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16) { + return false; + } + // TODO: support attention sinks [TAG_ATTN_SINKS] + if (op->src[4]) { + return false; + } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] % 16 != 0) { + // TODO: padding to support + return false; + } + float logitSoftcap = 0.0f; + memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); + if (logitSoftcap != 0.0f) { + return false; + } + return true; } - if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){ - return false; - } - if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){ - return false; - } - // TODO: support attention sinks [TAG_ATTN_SINKS] - if (op->src[4]) { - return false; - } - if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet - return false; - } - if (op->src[0]->ne[0] % 16 != 0) { - // TODO: padding to support - return false; - } - float logitSoftcap = 0.0f; - memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float)); - if(logitSoftcap != 0.0f) { - return false; - } - return true; - } default: return false; } @@ -2717,8 +2626,7 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { * @return bool Returns true if the operation should be offloaded, otherwise * false. */ -static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, - const ggml_tensor* op) { +static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { const int min_batch_size = 32; GGML_UNUSED(dev); @@ -2734,9 +2642,8 @@ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, * @param event Pointer to the event structure to be recorded. */ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream())); + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; + ACL_CHECK(aclrtRecordEvent((aclrtEvent) event->context, cann_ctx->stream())); } /** @@ -2749,13 +2656,10 @@ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_ * @param event Pointer to the event structure that the backend needs to wait * for. */ -static void ggml_backend_cann_event_wait(ggml_backend_t backend, - ggml_backend_event_t event) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; +static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; if (ggml_backend_is_cann(backend)) { - ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), - (aclrtEvent)event->context)); + ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent) event->context)); } else { GGML_ABORT("fatal error"); } @@ -2794,30 +2698,30 @@ static const ggml_backend_i ggml_backend_cann_interface = { * @return A pointer to the static GUID. */ static ggml_guid_t ggml_backend_cann_guid() { - static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34, - 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64}; + static ggml_guid guid = { 0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34, + 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64 }; return &guid; } // backend device struct ggml_backend_cann_device_context { - int device; + int device; std::string name; std::string description; }; static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) { - ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context; return ctx->name.c_str(); } -static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) { - ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; +static const char * ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context; return ctx->description.c_str(); } static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context; ggml_backend_cann_get_device_memory(ctx->device, free, total); } @@ -2844,7 +2748,7 @@ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_back static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); - ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context; return ggml_backend_cann_init(ctx->device); } @@ -2861,19 +2765,17 @@ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, cons * @return bool Returns true if the CANN backend supports the buffer type, * otherwise false. */ -static bool ggml_backend_cann_supports_buft( - ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { +static bool ggml_backend_cann_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { if (ggml_backend_buft_is_cann(buft)) { - ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; - ggml_backend_cann_buffer_type_context * buft_ctx = - (ggml_backend_cann_buffer_type_context *)buft->context; + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context; + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; return buft_ctx->device == dev_ctx->device; } return false; } static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) { - ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context; return ggml_backend_cann_buffer_type(ctx->device); } @@ -2892,9 +2794,8 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type( * @param backend Pointer to the CANN backend. * @return ggml_backend_event_t Returns a pointer to the new event structure. */ -static ggml_backend_event_t ggml_backend_cann_device_event_new( - ggml_backend_dev_t dev) { - ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; +static ggml_backend_event_t ggml_backend_cann_device_event_new(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context; ggml_cann_set_device(dev_ctx->device); @@ -2916,7 +2817,7 @@ static ggml_backend_event_t ggml_backend_cann_device_event_new( * @param event Pointer to the event structure to be freed. */ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { - ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context)); + ACL_CHECK(aclrtDestroyEvent((aclrtEvent) event->context)); delete event; GGML_UNUSED(dev); @@ -2930,7 +2831,7 @@ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_bac * @param event Pointer to the event structure to be synchronized. */ static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { - ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context)); + ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent) event->context)); GGML_UNUSED(dev); } @@ -2941,10 +2842,10 @@ static const ggml_backend_device_i ggml_backend_cann_device_interface = { /* .get_memory = */ ggml_backend_cann_device_get_memory, /* .get_type = */ ggml_backend_cann_device_get_type, /* .get_props = */ ggml_backend_cann_device_get_props, - /* .init_backend = */ ggml_backend_cann_device_init, // called for every card + /* .init_backend = */ ggml_backend_cann_device_init, // called for every card /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type, /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type, - /* .buffer_from_host_ptr = */ NULL, // not supported for CANN + /* .buffer_from_host_ptr = */ NULL, // not supported for CANN /* .supports_op = */ ggml_backend_cann_supports_op, /* .supports_buft = */ ggml_backend_cann_supports_buft, /* .offload_op = */ ggml_backend_cann_offload_op, @@ -2953,7 +2854,6 @@ static const ggml_backend_device_i ggml_backend_cann_device_interface = { /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize, }; - // backend reg struct ggml_backend_cann_reg_context { std::vector devices; @@ -2965,12 +2865,12 @@ static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) { } static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) { - ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context; + ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context; return ctx->devices.size(); } static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) { - ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context; + ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context; GGML_ASSERT(index < ctx->devices.size()); return ctx->devices[index]; } @@ -2992,34 +2892,30 @@ static const ggml_backend_reg_i ggml_backend_cann_reg_interface = { // backend registry, called only once for cann backend ggml_backend_reg_t ggml_backend_cann_reg() { static ggml_backend_reg reg; - static bool initialized = false; + static bool initialized = false; { - static std::mutex mutex; + static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { aclInit(nullptr); ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context; for (int i = 0; i < ggml_cann_info().device_count; i++) { - ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context(); - dev_ctx->description = aclrtGetSocName(); - dev_ctx->device = i; - dev_ctx->name = GGML_CANN_NAME + std::to_string(i); + ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context(); + dev_ctx->description = aclrtGetSocName(); + dev_ctx->device = i; + dev_ctx->name = GGML_CANN_NAME + std::to_string(i); ggml_cann_set_device(i); - ggml_backend_dev_t dev = new ggml_backend_device { - /* .iface = */ ggml_backend_cann_device_interface, - /* .reg = */ ®, - /* .context = */ dev_ctx - }; + ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx }; ctx->devices.push_back(dev); } - reg = ggml_backend_reg { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_cann_reg_interface, - /* .context = */ ctx - }; + reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_cann_reg_interface, + /* .context = */ ctx }; } initialized = true; @@ -3035,39 +2931,36 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) { return nullptr; } - ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device); + ggml_backend_cann_context * ctx = new ggml_backend_cann_context(device); if (ctx == nullptr) { GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); return nullptr; } ggml_cann_set_device(ctx->device); ggml_backend_t cann_backend = - new ggml_backend{/* .guid = */ ggml_backend_cann_guid(), - /* .interface = */ ggml_backend_cann_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device), - /* .context = */ ctx}; + new ggml_backend{ /* .guid = */ ggml_backend_cann_guid(), + /* .interface = */ ggml_backend_cann_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device), + /* .context = */ ctx }; return cann_backend; } bool ggml_backend_is_cann(ggml_backend_t backend) { - return backend != NULL && - ggml_guid_matches(backend->guid, ggml_backend_cann_guid()); + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid()); } int32_t ggml_backend_cann_get_device_count() { return ggml_cann_info().device_count; } -void ggml_backend_cann_get_device_description( - int32_t device, char* description, size_t description_size) { +void ggml_backend_cann_get_device_description(int32_t device, char * description, size_t description_size) { ggml_cann_set_device(device); - const char* soc_name = aclrtGetSocName(); + const char * soc_name = aclrtGetSocName(); snprintf(description, description_size, "%s", soc_name); } -void ggml_backend_cann_get_device_memory(int32_t device, size_t* free, - size_t* total) { +void ggml_backend_cann_get_device_memory(int32_t device, size_t * free, size_t * total) { ggml_cann_set_device(device); ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total)); } From 7bb53032b3185ba2dd37c3bc8e5cc7e4d44c201e Mon Sep 17 00:00:00 2001 From: GittyBurstein Date: Thu, 16 Oct 2025 16:26:21 +0300 Subject: [PATCH 085/104] sycl : add ARANGE operator (llama/16362) * SYCL: update element-wise ops and presets * clean arange * Re-trigger CI --------- Co-authored-by: Gitty Burstein --- ggml/src/ggml-sycl/element_wise.cpp | 32 +++++++++++++++++++++++++++++ ggml/src/ggml-sycl/element_wise.hpp | 2 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 5 +++++ ggml/src/ggml-sycl/presets.hpp | 1 + 4 files changed, 40 insertions(+) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index aeeb3875..58f5125c 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -397,6 +397,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, }); } +template +static void arange_kernel(T * dst, const int k, T start, T step, + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = start + static_cast(i) * step; + } +} + template static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, @@ -565,6 +573,25 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx } +static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + float start, stop, step; + memcpy(&start, dst->op_params, sizeof(float)); + memcpy(&stop, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&step, (float *) dst->op_params + 2, sizeof(float)); + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + float * dst_ptr = (float *)dst->data; + const int k = (int)ggml_nelements(dst); + const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE), + sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + arange_kernel(dst_ptr, k, start, step, item_ct1); + }); +} + } // namespace ggml_sycl_detail @@ -1090,3 +1117,8 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_geglu_quick(ctx, dst); } + +void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); + ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 43474317..ed96c55f 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -81,4 +81,6 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + #endif // GGML_SYCL_ELEMENTWISE_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f3407a81..9e557972 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3832,6 +3832,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_ARANGE: + ggml_sycl_arange(ctx, dst); + break; default: return false; } @@ -4478,6 +4481,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; + case GGML_OP_ARANGE: + return op->type == GGML_TYPE_F32; default: return false; } diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index af189072..0814bd79 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -49,6 +49,7 @@ #define SYCL_ARGMAX_BLOCK_SIZE 256 #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256 #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256 +#define SYCL_ARANGE_BLOCK_SIZE 256 // dmmv = dequantize_mul_mat_vec #ifndef GGML_SYCL_DMMV_X From 82332cea27bf2f0c353a6d2c7ca93b18903a7995 Mon Sep 17 00:00:00 2001 From: GittyBurstein Date: Fri, 17 Oct 2025 05:36:40 +0300 Subject: [PATCH 086/104] SYCL SET operator optimized for F32 tensors (llama/16350) * SYCL/SET: implement operator + wire-up; docs/ops updates; element_wise & ggml-sycl changes * sycl(SET): re-apply post-rebase; revert manual docs/ops.md; style cleanups * move SET op to standalone file, GPU-only implementation * Update SYCL SET operator for F32 * ci: fix editorconfig issues (LF endings, trailing spaces, final newline) * fixed ggml-sycl.cpp --------- Co-authored-by: Gitty Burstein --- ggml/src/ggml-sycl/ggml-sycl.cpp | 10 +++++ ggml/src/ggml-sycl/presets.hpp | 1 + ggml/src/ggml-sycl/set.cpp | 73 ++++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/set.hpp | 5 +++ 4 files changed, 89 insertions(+) create mode 100644 ggml/src/ggml-sycl/set.cpp create mode 100644 ggml/src/ggml-sycl/set.hpp diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 9e557972..a7e077ec 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -42,6 +42,7 @@ #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/set_rows.hpp" +#include "ggml-sycl/set.hpp" #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml-sycl/quantize.hpp" @@ -3619,6 +3620,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GET_ROWS: ggml_sycl_get_rows(ctx, dst); break; + case GGML_OP_SET: + ggml_sycl_op_set(ctx, dst); + break; case GGML_OP_SET_ROWS: ggml_sycl_op_set_rows(ctx, dst); break; @@ -4331,6 +4335,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return false; } } + case GGML_OP_SET: + return (op->type == GGML_TYPE_F32) && + (op->src[0] && op->src[1]) && + (op->src[0]->type == GGML_TYPE_F32) && + (op->src[1]->type == GGML_TYPE_F32); + case GGML_OP_SET_ROWS: { return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index 0814bd79..b6517374 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -31,6 +31,7 @@ #define SYCL_SQRT_BLOCK_SIZE 256 #define SYCL_SIN_BLOCK_SIZE 256 #define SYCL_SQR_BLOCK_SIZE 256 +#define SYCL_SET_BLOCK_SIZE 256 #define SYCL_CPY_BLOCK_SIZE 32 #define SYCL_SCALE_BLOCK_SIZE 256 #define SYCL_CLAMP_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-sycl/set.cpp b/ggml/src/ggml-sycl/set.cpp new file mode 100644 index 00000000..381326d2 --- /dev/null +++ b/ggml/src/ggml-sycl/set.cpp @@ -0,0 +1,73 @@ +#include "presets.hpp" +#include "common.hpp" +#include "ggml.h" +#include "set.hpp" +#include +#include +using namespace sycl; + +// Internal function: perform element-wise set operation for each thread +inline void set_f32(const float* src, float* dst, + const int64_t ne0, const int64_t ne1, + const int64_t ne2, const int64_t ne3, + const int64_t nb[3], const int64_t src_nb[3], + const int64_t offset_elem, + const nd_item<1>& item) +{ + const size_t idx = item.get_global_id(0); + const size_t total = ne0 * ne1 * ne2 * ne3; + if (idx >= total) return; + + // Convert linear index to 4D indices + const size_t i3 = idx / (ne2 * ne1 * ne0); + const size_t rem = idx % (ne2 * ne1 * ne0); + const size_t i2 = rem / (ne1 * ne0); + const size_t rem2 = rem % (ne1 * ne0); + const size_t i1 = rem2 / ne0; + const size_t i0 = rem2 % ne0; + + // Compute source and destination indices and copy + dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] = + src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]]; +} + +// Main function: prepare GPU queue and launch parallel_for +void ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor* src0 = dst->src[0]; + const ggml_tensor* src1 = dst->src[1]; + + // Ensure shapes and types are compatible + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32); + + const int32_t* opts = (const int32_t*) dst->op_params; + const int64_t nb[3] = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)}; + const int64_t offset_elem = opts[3] / sizeof(float); + const bool inplace = opts[4]; + + float* dst_ptr = (float*) dst->data; + const float* src0_ptr = (const float*) src0->data; + const float* src1_ptr = (const float*) src1->data; + + queue_ptr stream = ctx.stream(); + + // Copy src0 to dst if not inplace + if (!inplace) + stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst)); + + const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]}; + const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)}; + + const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3]; + const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE; + + // Copy src0 to dst if not inplace + stream->parallel_for( + nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)), + [=](nd_item<1> item) { + set_f32(src1_ptr, dst_ptr, + ne[0], ne[1], ne[2], ne[3], + nb, src_nb, offset_elem, item); } + ); +} diff --git a/ggml/src/ggml-sycl/set.hpp b/ggml/src/ggml-sycl/set.hpp new file mode 100644 index 00000000..657d7ac9 --- /dev/null +++ b/ggml/src/ggml-sycl/set.hpp @@ -0,0 +1,5 @@ +#pragma once +#include "backend.hpp" +#include "ggml.h" + +void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst); From 0ae492641cc25232a84b0c98d0c7a966789c6bbc Mon Sep 17 00:00:00 2001 From: Ilia Ilmer Date: Fri, 17 Oct 2025 02:33:58 -0400 Subject: [PATCH 087/104] metal : add `CONV_TRANSPOSE_2D` (llama/16542) * initial: headers and metal-device.cpp updates * adding conv_transpose_2d * fix type * fix type: int32->int64 * Update ggml/src/ggml-metal/ggml-metal.metal Co-authored-by: Georgi Gerganov * Update ggml/src/ggml-metal/ggml-metal.metal Co-authored-by: Georgi Gerganov * Update ggml/src/ggml-metal/ggml-metal.metal Co-authored-by: Georgi Gerganov * add checks for src[0] and src[1]; add type checks * Update ggml-metal.metal Co-authored-by: Georgi Gerganov * add more tests, add optimization to threading * add dynamic memory allocation in metal --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 25 +++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 5 ++ ggml/src/ggml-metal/ggml-metal-impl.h | 13 ++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 60 +++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 91 +++++++++++++++++++++++ 7 files changed, 196 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 866cd2da..75811634 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 28ae2e17..4d582974 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c3c83abe..360fbe19 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SCALE: case GGML_OP_CONV_TRANSPOSE_1D: return true; + case GGML_OP_CONV_TRANSPOSE_2D: + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32; case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index fa2d82ce..96f43d26 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -514,6 +514,19 @@ typedef struct { uint64_t nb1; } ggml_metal_kargs_conv_transpose_1d; +typedef struct { + int32_t IC; + int32_t IH; + int32_t IW; + int32_t KH; + int32_t KW; + int32_t OC; + int32_t s0; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_conv_transpose_2d; + typedef struct { uint64_t ofs0; uint64_t ofs1; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 4f9f6bda..7a85edbd 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[2]; + const int32_t IH = op->src[1]->ne[1]; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = op->src[0]->ne[1]; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OW = op->ne[0]; + const int32_t OH = op->ne[1]; + const int32_t OC = op->ne[2]; + + ggml_metal_kargs_conv_transpose_2d args = { + /*.IC =*/ IC, + /*.IH =*/ IH, + /*.IW =*/ IW, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.OC =*/ OC, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + // Metal requires buffer size to be multiple of 16 bytes + const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1); + + return 1; +} + int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f3527386..0d9cb8af 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 496610b1..2c2f0141 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); + +typedef void (conv_transpose_2d_t)( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const float * src0, + device const float * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const T * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t out_x = tgpig[0]; + const int64_t out_y = tgpig[1]; + const int64_t out_c = tgpig[2]; + + const int64_t kw = tpitg[0]; + const int64_t kh = tpitg[1]; + + float v = 0.0f; + + for (int64_t in_c = 0; in_c < args.IC; in_c++) { + int64_t in_y = out_y - kh; + + if (in_y < 0 || in_y % args.s0) continue; + + in_y /= args.s0; + + if (in_y >= args.IH) continue; + + int64_t in_x = out_x - kw; + + if (in_x < 0 || in_x % args.s0) continue; + + in_x /= args.s0; + + if (in_x >= args.IW) continue; + + const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x; + const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw; + + v += (float)src0[kernel_idx] * src1[input_idx]; + } + + const uint tid = tpitg.y * ntg.x + tpitg.x; + shared_sum[tid] = v; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid == 0) { + float total = 0.0f; + const uint num_threads = ntg.x * ntg.y; + for (uint i = 0; i < num_threads; i++) { + total += shared_sum[i]; + } + + device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2); + dst_ptr[0] = total; + } +} + +template [[host_name("kernel_conv_transpose_2d_f32_f32")]] +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const float * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template [[host_name("kernel_conv_transpose_2d_f16_f32")]] +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const half * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + kernel void kernel_upscale_f32( constant ggml_metal_kargs_upscale & args, device const char * src0, From 4a384826a821e5fd25a6046d926d0fba3e18e8f0 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 17 Oct 2025 02:31:04 -0500 Subject: [PATCH 088/104] vulkan: fix debug build (add_rms_len/data not found) (llama/16624) --- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 184f3f3a..32f272e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -959,7 +959,7 @@ void write_output_files() { } std::string suffixes[2] = {"_f32", "_f16"}; - for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) { hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; From 328263f8fdf1959cb6a785f35a855e85a8197a55 Mon Sep 17 00:00:00 2001 From: muggle-stack Date: Fri, 17 Oct 2025 18:01:23 +0800 Subject: [PATCH 089/104] ggml : fix SpaceMit IME array out-of-bounds in task assignment (llama/16629) Fix incorrect task-to-batch index calculation in the quantization phase. The bug caused out-of-bounds access to qnbitgemm_args array when compute_idx exceeded per_gemm_block_count_m, leading to invalid pointer dereferences and SIGBUS errors. Correctly map tasks to batches by dividing compute_idx by per_gemm_block_count_m instead of block_size_m. Example: batch_feature=1, gemm_m=30, block_size_m=4 per_gemm_block_count_m = 8, task_count = 8 Old: gemm_idx = 4/4 = 1 (out of bounds New: gemm_idx = 4/8 = 0 (correct) Tested on SpaceMit K1 RISC-V64 with qwen2.5:0.5b model. Co-authored-by: muggle --- ggml/src/ggml-cpu/spacemit/ime.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 54d3dece..91fe1925 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -485,8 +485,9 @@ template class tensor_ int32_t start = ith * task_per_thread; int32_t end = std::min((ith + 1) * task_per_thread, task_count); for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / block_size_m; - int32_t m_idx = compute_idx % block_size_m * block_size_m; + int32_t gemm_idx = compute_idx / per_gemm_block_count_m; + int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; + int32_t m_idx = block_idx_in_gemm * block_size_m; const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); From d22008b631dd3de4748cbf426b33e3d273a13034 Mon Sep 17 00:00:00 2001 From: Giuseppe Scrivano Date: Fri, 17 Oct 2025 14:23:47 +0200 Subject: [PATCH 090/104] vulkan: Add State Space Model (SSM) Operations Support (llama/16463) * vulkan: implement SSM scan operation Add State Space Model scan operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano * vulkan: implement SSM conv operation Add State Space Model conv operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano --------- Signed-off-by: Giuseppe Scrivano --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 227 +++++++++++++++++- .../ggml-vulkan/vulkan-shaders/ssm_conv.comp | 44 ++++ .../ggml-vulkan/vulkan-shaders/ssm_scan.comp | 125 ++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + 4 files changed, 394 insertions(+), 6 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1674dc66..bc703611 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -582,6 +582,9 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_ssm_scan_f32_d128; + vk_pipeline pipeline_ssm_scan_f32_d256; + vk_pipeline pipeline_ssm_conv_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -1087,6 +1090,19 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t C; uint32_t H; }; +struct vk_op_ssm_scan_push_constants { + uint32_t nb02, nb03, nb12, nb13; + uint32_t nb21, nb22, nb31; + uint32_t nb42, nb43, nb52, nb53; + uint32_t s_off; + uint32_t n_head, d_head, n_group, n_tok; +}; +struct vk_op_ssm_conv_push_constants { + uint32_t nb01, nb02; + uint32_t nb11; + uint32_t dst_nb0, dst_nb1, dst_nb2; + uint32_t nc, ncs, nr, n_t, n_s; +}; struct vk_op_conv2d_push_constants { uint32_t Cout; @@ -3591,6 +3607,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -8098,6 +8119,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; + case GGML_OP_SSM_SCAN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t d_state = src0->ne[0]; + if (d_state == 128) { + return ctx->device->pipeline_ssm_scan_f32_d128; + } else if (d_state == 256) { + return ctx->device->pipeline_ssm_scan_f32_d256; + } + } + return nullptr; + case GGML_OP_SSM_CONV: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_ssm_conv_f32; + } + return nullptr; case GGML_OP_OPT_STEP_ADAMW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_opt_step_adamw_f32; @@ -8592,6 +8628,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } break; + case GGML_OP_SSM_CONV: + { + const uint32_t nr = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + elements = { nr, n_t, n_s }; + } + break; default: elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; break; @@ -9038,6 +9082,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = dst->src[3]; + const ggml_tensor * src4 = dst->src[4]; + const ggml_tensor * src5 = dst->src[5]; + + GGML_ASSERT(dst->buffer != nullptr); + + const uint32_t head_dim = src0->ne[1]; + const uint32_t n_head = src1->ne[1]; + const uint32_t n_group = src4->ne[1]; + const uint32_t n_tok = src1->ne[2]; + const uint32_t n_seq = src1->ne[3]; + + bool is_mamba2 = (src3->nb[1] == sizeof(float)); + GGML_ASSERT(is_mamba2); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + const vk_op_ssm_scan_push_constants pc = { + (uint32_t)src0->nb[2], (uint32_t)src0->nb[3], + (uint32_t)src1->nb[2], (uint32_t)src1->nb[3], + (uint32_t)src2->nb[1], (uint32_t)src2->nb[2], + (uint32_t)src3->nb[1], + (uint32_t)src4->nb[2], (uint32_t)src4->nb[3], + (uint32_t)src5->nb[2], (uint32_t)src5->nb[3], + (uint32_t)s_off, + n_head, head_dim, n_group, n_tok + }; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC]; + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } + + vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr }; + size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 }; + bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false }; + + if (ctx->device->uma) { + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + dst_uma = d_D != nullptr; + } + + if (!dst_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } + } + + size_t dst_size = ggml_nbytes(dst); + size_t src_sizes[GGML_MAX_SRC]; + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + } + + std::array elements; + + const int splitH = 16; + const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); + const uint32_t num_workgroups_y = n_seq; + elements = { num_workgroups_x, num_workgroups_y, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); +} + +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, { + (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], + (uint32_t)src1->nb[1], + (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], + (uint32_t)src1->ne[0], + (uint32_t)src0->ne[0], + (uint32_t)src0->ne[1], + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2], + }, dryrun); +} + static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { const ggml_tensor * x = dst->src[0]; const ggml_tensor * g = dst->src[1]; @@ -10870,6 +11025,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: + case GGML_OP_SSM_SCAN: + case GGML_OP_SSM_CONV: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_OPT_STEP_ADAMW: @@ -11287,6 +11444,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; + case GGML_OP_SSM_SCAN: + ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_SSM_CONV: + ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun); + + break; + case GGML_OP_OPT_STEP_ADAMW: ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); @@ -11398,6 +11565,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: + case GGML_OP_SSM_SCAN: + case GGML_OP_SSM_CONV: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: @@ -12879,6 +13048,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_SSM_SCAN: + { + for (int i = 0; i < 6; i++) { + if (op->src[i] && ggml_is_quantized(op->src[i]->type)) { + return false; + } + } + if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) { + return false; + } + if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { + return false; + } + + const uint32_t d_state = op->src[0]->ne[0]; + const uint32_t head_dim = op->src[0]->ne[1]; + + bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float)); + if (!is_mamba2) { + return false; + } + + if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + const uint32_t SPLIT_H = 16; + + size_t stateC_size = SPLIT_H * d_state * sizeof(float); + + if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } + + return true; + } + case GGML_OP_SSM_CONV: + return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -13223,14 +13433,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * struct ggml_context * ggml_ctx = ggml_init(iparams); - std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; - std::array src_size = {0, 0, 0, 0, 0, 0}; - std::array src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; - const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; + std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array src_size = {}; + std::array src_buffer = {}; + const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"}; struct ggml_tensor * tensor_clone = nullptr; - for (int i = 0; i < 6; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { ggml_tensor * srci = tensor->src[i]; if (fused_rms_norm_mul) { rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; @@ -13537,6 +13747,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[2]); } else if (tensor->op == GGML_OP_ADD_ID) { tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SSM_SCAN) { + tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], + src_clone[3], src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_SSM_CONV) { + tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -13558,7 +13773,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); - for (int i = 0; i < 6; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (src_buffer[i] != nullptr) { free(src_buffer[i]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp new file mode 100644 index 00000000..d62696bc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -0,0 +1,44 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer Src0 { float src0[]; }; +layout(binding = 1) readonly buffer Src1 { float src1[]; }; +layout(binding = 2) buffer Dst { float dst[]; }; + +layout(push_constant) uniform PushConstants { + uint nb01; uint nb02; + uint nb11; + uint dst_nb0; uint dst_nb1; uint dst_nb2; + uint nc; uint ncs; uint nr; uint n_t; uint n_s; +}; + +void main() { + const uint global_thread_id = gl_GlobalInvocationID.x; + const uint i2 = gl_WorkGroupID.y; + const uint i3 = gl_WorkGroupID.z; + + if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) { + return; + } + + const uint i1 = global_thread_id; + const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4); + const uint src1_base = i1 * (nb11 / 4); + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; + + float sum = 0.0; + [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { + const uint src0_idx = src0_base + i0; + const uint src1_idx = src1_base + i0; + sum += src0[src0_idx] * src1[src1_idx]; + } + + dst[dst_idx] = sum; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp new file mode 100644 index 00000000..12bd1745 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -0,0 +1,125 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout(constant_id = 0) const uint D_STATE = 128; +layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 2) const uint SPLIT_H = 16; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer Src0 { float s0[]; }; +layout(binding = 1) readonly buffer Src1 { float x[]; }; +layout(binding = 2) readonly buffer Src2 { float dt[]; }; +layout(binding = 3) readonly buffer Src3 { float A[]; }; +layout(binding = 4) readonly buffer Src4 { float B[]; }; +layout(binding = 5) readonly buffer Src5 { float C[]; }; +layout(binding = 6) readonly buffer Src6 { int ids[]; }; +layout(binding = 7) buffer Dst { float d[]; }; + +layout(push_constant) uniform PushConstants { + uint nb02; uint nb03; uint nb12; uint nb13; + uint nb21; uint nb22; uint nb31; + uint nb42; uint nb43; uint nb52; uint nb53; + uint s_off; + uint n_head; + uint d_head; + uint n_group; + uint n_tok; +}; + +float softplus(float x) { + if (x <= 20.0) { + return log(1.0 + exp(x)); + } else { + return x; + } +} + +shared float stateC[SPLIT_H * D_STATE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head; + const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; + + const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; + const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; + const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4; + const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; + const uint A_base_idx = (head_idx * nb31) / 4; + const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; + const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H; + const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; + + const uint stride_x = nb12 / 4; + const uint stride_dt = nb21 / 4; + const uint stride_B = nb42 / 4; + const uint stride_C = nb52 / 4; + const uint stride_y = n_head * d_head; + + float state[SPLIT_H]; + [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { + state[j] = s0[s0_base_idx + j * D_STATE + tid]; + } + + for (uint i = 0; i < n_tok; i++) { + const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + + const float dA = exp(dt_soft_plus * A[A_base_idx]); + + const float B_val = B[B_base_idx + i * stride_B + tid]; + const float C_val = C[C_base_idx + i * stride_C + tid]; + + [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { + const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus; + + state[j] = (state[j] * dA) + (B_val * x_dt); + + stateC[j * D_STATE + tid] = state[j] * C_val; + } + + barrier(); + for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) { + [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) { + const uint k = (tid % (w >> 1)) + + (D_STATE * (tid / (w >> 1))) + + j * D_STATE * (D_STATE / (w >> 1)); + if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) { + stateC[k] += stateC[k + (w >> 1)]; + } + } + barrier(); + } + + [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) { + const uint idx = (tid % SUBGROUP_SIZE) + + D_STATE * (tid / SUBGROUP_SIZE) + + j * D_STATE * (D_STATE / SUBGROUP_SIZE); + + uint lane = tid % SUBGROUP_SIZE; + + [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { + if (idx + offset < SPLIT_H * D_STATE) { + stateC[idx] += stateC[idx + offset]; + } + barrier(); + } + + if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) { + const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); + d[y_base_idx + i * stride_y + k] = stateC[idx]; + } + } + + barrier(); + } + + [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { + d[s_base_idx + j * D_STATE + tid] = state[j]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 32f272e9..1d04a812 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -916,6 +916,10 @@ void process_shaders() { string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); + + string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}}); + for (auto &c : compiles) { c.wait(); } From 6aa18cccd87893cc86773fc93a5dc54a52692c70 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 17 Oct 2025 18:02:52 +0300 Subject: [PATCH 091/104] rpc : report actual free memory (llama/16616) * rpc : report actual free memory Start reporting the free memory on every device instead of using fixed values. Now llama-cli users can get a nice memory breakdown when using RPC devices. * drop --mem in rpc-server --- ggml/include/ggml-rpc.h | 3 +-- ggml/src/ggml-rpc/ggml-rpc.cpp | 39 +++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 72eff002..e6dca3f6 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, - size_t n_threads, size_t n_devices, - ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index aad48d62..a38df5a9 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -939,6 +939,7 @@ public: bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); bool init_tensor(const rpc_msg_init_tensor_req & request); bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); + bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); private: bool get_cached_file(uint64_t hash, std::vector & data); @@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph return true; } +bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + size_t free, total; + ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]); + ggml_backend_dev_memory(dev, &free, &total); + response.free_mem = free; + response.total_mem = total; + LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem); + return true; +} + rpc_server::~rpc_server() { for (auto buffer : buffers) { ggml_backend_buffer_free(buffer); @@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - sockfd_t sockfd, const std::vector & free_mem, const std::vector & total_mem) { + sockfd_t sockfd) { rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { @@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector & backends, const if (!recv_msg(sockfd, &request, sizeof(request))) { return; } - auto dev_id = request.device; - if (dev_id >= backends.size()) { + rpc_msg_get_device_memory_rsp response; + if (!server.get_device_memory(request, response)) { return; } - rpc_msg_get_device_memory_rsp response; - response.free_mem = free_mem[dev_id]; - response.total_mem = total_mem[dev_id]; - LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id, - response.free_mem, response.total_mem); if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector & backends, const } void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, - size_t n_threads, size_t n_devices, - ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) { - if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) { + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) { + if (n_devices == 0 || devices == nullptr) { fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); return; } std::vector backends; - std::vector free_mem_vec(free_mem, free_mem + n_devices); - std::vector total_mem_vec(total_mem, total_mem + n_devices); printf("Starting RPC server v%d.%d.%d\n", RPC_PROTO_MAJOR_VERSION, RPC_PROTO_MINOR_VERSION, @@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir printf("Devices:\n"); for (size_t i = 0; i < n_devices; i++) { auto dev = devices[i]; + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024); + total / 1024 / 1024, free / 1024 / 1024); auto backend = ggml_backend_dev_init(dev, nullptr); if (!backend) { fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); @@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec); + rpc_serve_client(backends, cache_dir, client_socket->fd); printf("Client connection closed\n"); fflush(stdout); } From 8ffdf4bd963bfe4437f35620d884884055a68f64 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Fri, 17 Oct 2025 17:55:32 -0700 Subject: [PATCH 092/104] opencl: transposed gemm/gemv moe kernel with mxfp4,f32 (llama/16602) * opencl: transposed gemm/gemv moe kernel with mxfp4,f32 * add restore kernel for moe transpose * fix trailing whitespaces * resolve compilation warnings --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 213 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 42 ++++ .../ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl | 162 +++++++++++++ .../ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl | 156 +++++++++++++ 5 files changed, 567 insertions(+), 8 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 6f6bba55..d3d97f37 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -91,6 +91,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_q8_0_f32_flat mul_mv_id_mxfp4_f32 mul_mv_id_mxfp4_f32_flat + gemm_moe_mxfp4_f32 + gemv_moe_mxfp4_f32 mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q8_0_f32_l4_lm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2ec896fd..d9876e69 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -402,6 +402,7 @@ struct ggml_backend_opencl_context { cl_program program_conv_2d_f32; cl_program program_conv_2d_f16_f32; cl_program program_tsembd; + cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32; cl_program program_mul_mv_id_q4_0_f32_8x_flat; cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat; cl_program program_mul_mv_id_mxfp4_f32; @@ -452,7 +453,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32_tiled; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; - cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4; + cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; @@ -475,6 +476,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_timestep_embedding; + cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat; cl_kernel kernel_mul_mv_id_mxfp4_f32; @@ -559,14 +561,14 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "[\n"); for (const ProfilingInfo & info : profiling_info) { - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_submit/1000); - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_start/1000); - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } fclose(ftrace); @@ -777,6 +779,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); @@ -1991,6 +1995,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); GGML_LOG_CONT("."); } + + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -cl-fast-relaxed-math"; + + // gemv_moe_mxfp4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_mxfp4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl"); +#endif + backend_ctx->program_gemv_moe_mxfp4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemm_moe_mxfp4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl"); +#endif + backend_ctx->program_gemm_moe_mxfp4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err)); + GGML_LOG_CONT("."); + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); } @@ -3299,6 +3339,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c tensor->ne[2] == 1 && tensor->ne[3] == 1; } +inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + GGML_UNUSED(backend_ctx); + int ne01 = tensor->ne[1]; + return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); +} + static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -3601,14 +3647,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + tensor->extra = extra; + + return; + } +#endif cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; cl_event evt; CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); @@ -3624,7 +3695,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, { extra->q } }; extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); - tensor->extra = extra; return; @@ -3751,6 +3821,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); @@ -7553,6 +7650,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne21 = src2->ne[1]; const cl_ulong nb21 = src2->nb[1]; + const cl_ulong nb20 = src2->nb[0]; const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; @@ -7692,6 +7790,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, break; } case GGML_TYPE_MXFP4: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + int tile_size = 320; + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32; + + // preprocess router table + int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size; + void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short)); + void * host_src2 = malloc(ne21 * nb21); + CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL)); + int total_experts = nb21 / nb20; + int out_idx = 0; + for (int i_expert = 0; i_expert < ne02; i_expert++) { + for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) { + for (int j = 0; j < ne21; j++) { + for (int i = 0; i < ne20; i++) { + int expert = ((int *)host_src2)[j * total_experts + i]; + if (i_expert == expert) { + ((short *)host_src2_reorder)[out_idx] = static_cast(expert); + ((short *)host_src2_reorder)[out_idx + 1] = static_cast(j * ne11 + (i % ne11)); + ((short *)host_src2_reorder)[out_idx + 2] = static_cast(j * ne20 + i); + ((short *)host_src2_reorder)[out_idx + 3] = static_cast(i_tile); + out_idx += 4; + } + } + } + } + } + buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(tile_size); + global_size[2] = static_cast(ne20 * ne21 * num_tiles_per_expert); + } + + // create a sub_buffer for src1 + cl_buffer_region region; + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + if (ne12 == 1) { + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + } else { + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size)); + } + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + return; + } // else fallback to generic kernel +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 045300eb..b26f9c5f 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4( } } +kernel void kernel_convert_block_mxfp4_trans( + global struct block_mxfp4 * src0, + __global uint4 * dst_q, + __global uchar * dst_e, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = src0 + src_blk_offset; + + dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0]; + dst_e[dst_blk_offset] = b->e; +} + kernel void kernel_restore_block_mxfp4( global uchar * src_q, global half * src_e, @@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4( } } +kernel void kernel_restore_block_mxfp4_trans( + __global uint4 * src_q, + __global uchar * src_e, + global struct block_mxfp4 * dst, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = dst + dst_blk_offset; + + ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset]; + b->e = src_e[src_blk_offset]; +} + //------------------------------------------------------------------------------ // block_q8_0 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl new file mode 100644 index 00000000..3917aa3f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl @@ -0,0 +1,162 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 2 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemm_moe_mxfp4_f32( + __global uint4 * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global ushort4 * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int tile_size +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + ushort4 router = src2[i20]; + ushort expert_id = router.x; + ushort i11 = router.y; + ushort i1 = router.z; + ushort tile_id = router.w; + + if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size + return; + } + + uint expert_offset = expert_id * ne00 * ne01 / 32; + uint tile_offset = expert_offset + tile_id * tile_size + i01; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + // load one block of q + uint4 regQ = src0_q[tile_offset + ib00 * ne01]; + // convert 8 fp4 to fp16 + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + uint offset = i11 * ne00 / 4 + ib00 * 8; + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + uchar regE = src0_e[tile_offset + ib00 * ne01]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + // if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + // if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + tile_id * tile_size + i1 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl new file mode 100644 index 00000000..b4b1e511 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl @@ -0,0 +1,156 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_mxfp4_f32( + __global uint4 * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From 08345f15ece9bdc528770596c08b48144082e933 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 18 Oct 2025 17:52:53 +0800 Subject: [PATCH 093/104] CUDA: use registers instead of smem in topk-moe (llama/16647) Uses the technique used in the vulkan PR #16641. Neat trick! --- ggml/src/ggml-cuda/topk-moe.cu | 42 +++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index afe4aee2..c588da2b 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float wt_sum = 0.f; - extern __shared__ float data_topk_shared[]; - float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used; + float output_weights[experts_per_thread]; for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; @@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } + if ((k & (WARP_SIZE - 1)) == threadIdx.x) { + output_weights[k / WARP_SIZE] = max_val; + } + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { wt[max_expert / WARP_SIZE] = -INFINITY; - wt_shared_ptr[k] = max_val; - ids[k] = max_expert; + ids[k] = max_expert; if constexpr (with_norm) { wt_sum += max_val; } @@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * const float inv_sum = 1.0f / wt_sum; for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { - wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; + output_weights[i] *= inv_sum; } } - for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { - weights[i] = wt_shared_ptr[i]; +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = i * WARP_SIZE + threadIdx.x; + if (idx < n_expert_used) { + weights[idx] = output_weights[i]; + } } } @@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); - const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float); - switch (n_expert) { case 1: topk_moe_cuda<1, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 2: topk_moe_cuda<2, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 4: topk_moe_cuda<4, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 8: topk_moe_cuda<8, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 16: topk_moe_cuda<16, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 32: topk_moe_cuda<32, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 64: topk_moe_cuda<64, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 128: topk_moe_cuda<128, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 256: topk_moe_cuda<256, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; case 512: topk_moe_cuda<512, with_norm> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used); break; default: GGML_ASSERT(false && "fatal error"); From 414901a42c4ac9998615d45ee4e0f6cfe3064377 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 18 Oct 2025 05:22:57 -0500 Subject: [PATCH 094/104] vulkan: Implement topk_moe fused shader, ported from CUDA (llama/16641) This is similar to the CUDA shader from #16130, but doesn't use shared memory and handles different subgroup sizes. --- ggml/src/ggml-impl.h | 13 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 266 +++++++++++++++++- .../ggml-vulkan/vulkan-shaders/topk_moe.comp | 139 +++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 + 4 files changed, 412 insertions(+), 8 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index d0fb3bcc..18f095b8 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) +static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) { + const struct ggml_tensor * node = cgraph->nodes[node_idx]; + + size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); + if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) { + return 0; + } + return cgraph->use_counts[hash_pos]; +} + // return true if the node's results are only used by N other nodes // and can be fused into their calculations. static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) { const struct ggml_tensor * node = cgraph->nodes[node_idx]; // check the use count against how many we're replacing - size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); - if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) { + if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) { return false; } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bc703611..21bd0522 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -385,6 +385,14 @@ enum shader_reduction_mode { static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); +static constexpr uint32_t num_topk_moe_pipelines = 10; + +static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; +static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }; + struct vk_device_struct { std::recursive_mutex mutex; @@ -598,6 +606,9 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; + // [2] is {!norm, norm} + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; + std::vector all_pipelines; std::vector> pinned_memory; @@ -941,6 +952,11 @@ struct vk_op_multi_add_push_constants { static_assert(MAX_PARAMETER_COUNT == 12); static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); +struct vk_op_topk_moe_push_constants { + uint32_t n_rows; + uint32_t n_expert_used; +}; + struct vk_op_add_id_push_constants { uint32_t ne0; uint32_t ne1; @@ -3722,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + if (ctx->num_additional_fused_ops) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + GGML_ASSERT(idx < num_topk_moe_pipelines); + bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + return ctx->device->pipeline_topk_moe[idx][with_norm]; + } + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } @@ -9589,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } +static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { + + bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; + ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; + ggml_tensor * ids = cgraph->nodes[node_idx + 3]; + + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + const int n_expert_used = weights->ne[1]; + + GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context; + ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_logits = nullptr; + size_t logits_buf_offset = 0; + vk_buffer d_weights = nullptr; + size_t weights_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool logits_uma = false; + bool weights_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset); + ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + logits_uma = d_logits != nullptr; + weights_uma = d_weights != nullptr; + ids_uma = d_ids != nullptr; + } + + if (!logits_uma) { + d_logits = logits_buf_ctx->dev_buffer; + logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs; + GGML_ASSERT(d_logits != nullptr); + } + if (!weights_uma) { + d_weights = weights_buf_ctx->dev_buffer; + weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs; + GGML_ASSERT(d_weights != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + + vk_op_topk_moe_push_constants pc; + pc.n_rows = n_rows; + pc.n_expert_used = n_expert_used; + + GGML_ASSERT(n_expert_used <= n_experts); + + const uint32_t rows_per_block = 4; + std::array elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset), + ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset), + ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset), + }, pc, elements); +} + static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -11174,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ctx->unsynced_nodes_read.clear(); ggml_vk_sync_buffers(ctx, compute_ctx); } - // Add the last fused node and all fused source nodes to the unsynchronized list. - const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; - ctx->unsynced_nodes_written.push_back(last_node); + // Add all fused nodes to the unsynchronized lists. for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list. + ctx->unsynced_nodes_written.push_back(cur_node); for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { if (!cur_node->src[j]) { continue; @@ -11345,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SOFT_MAX: - ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); + if (ctx->num_additional_fused_ops) { + ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); + } break; case GGML_OP_SOFT_MAX_BACK: @@ -12141,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st return true; } +static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, bool with_norm) { + + if (with_norm) { + if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) { + return false; + } + for (size_t i = 0; i < topk_moe_norm.size(); ++i) { + if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) { + return false; + } + } + } else { + if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) { + return false; + } + for (size_t i = 0; i < topk_moe.size(); ++i) { + if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) { + return false; + } + } + } + + const ggml_tensor * softmax = cgraph->nodes[node_idx + 0]; + const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; + + const float * op_params = (const float *)softmax->op_params; + + float scale = op_params[0]; + float max_bias = op_params[1]; + + if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { + return false; + } + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } + + const int n_expert = softmax->ne[0]; + // n_expert must be a power of 2 + if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) { + return false; + } + + // Check that the nodes don't have any unexpected uses + const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1]; + const ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; + const ggml_tensor * view = cgraph->nodes[node_idx + 3]; + const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; + const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr; + const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr; + const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr; + const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr; + + // softmax is used by reshape and argsort + if (ggml_node_get_use_count(cgraph, node_idx) != 2 || + reshape1->src[0] != softmax || + argsort->src[0] != softmax) { + return false; + } + // reshape is used by get_rows + if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 || + get_rows->src[0] != reshape1) { + return false; + } + // argsort is used by view + if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 || + view->src[0] != argsort) { + return false; + } + // view is written (via argsort), we can skip checking it + + if (with_norm) { + // get_rows is used by reshape + if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 || + reshape5->src[0] != get_rows) { + return false; + } + + // reshape is used by sum_rows and div + if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 || + sum_rows->src[0] != reshape5 || + div->src[0] != reshape5) { + return false; + } + + // sum_rows is used by div + if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 || + div->src[1] != sum_rows) { + return false; + } + + // div/reshape are written + if (reshape8->src[0] != div) { + return false; + } + } + + if (!ctx->device->subgroup_arithmetic || + !ctx->device->subgroup_shuffle || + !ctx->device->subgroup_require_full_support || + ctx->device->disable_fusion) { + return false; + } + + return true; +} + static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { const ggml_tensor *first_node = cgraph->nodes[node_idx]; @@ -12216,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; + } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { + ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; + } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { + ctx->num_additional_fused_ops = topk_moe.size() - 1; } } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); @@ -12313,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; + } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { + ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; + } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { + ctx->num_additional_fused_ops = topk_moe.size() - 1; } } @@ -12320,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) || - (i + ctx->num_additional_fused_ops == last_node) || + (i + ctx->num_additional_fused_ops >= last_node) || (almost_ready && !ctx->almost_ready_fence_pending); - bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); + bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit); if (vk_perf_logger_enabled) { if (ctx->compute_ctx.expired()) { @@ -12444,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * while (first_unused < graph->n_nodes) { std::vector current_set; + // Avoid reordering topk_moe_norm + if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) { + bool is_topk_moe_norm = true; + for (size_t j = 0; j < topk_moe_norm.size(); ++j) { + if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) { + is_topk_moe_norm = false; + } + } + if (is_topk_moe_norm) { + for (size_t j = 0; j < topk_moe_norm.size(); ++j) { + new_order.push_back(graph->nodes[first_unused + j]); + used[first_unused + j] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + continue; + } + } // First, grab the next unused node. current_set.push_back(first_unused); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp new file mode 100644 index 00000000..9e56d5f8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -0,0 +1,139 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint n_rows; + uint n_expert_used; +}; + +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + +layout(constant_id = 0) const uint WARP_SIZE = 32; +layout(constant_id = 1) const uint n_experts = 512; +layout(constant_id = 2) const bool with_norm = true; + +const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; + +layout (binding = 0, std430) readonly buffer Logits {float logits[];}; +layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; +layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; + +void main() { + const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; + if (row >= n_rows) { + return; + } + + const uint logits_offset = n_experts * row; + const uint weights_offset = n_expert_used * row; + const uint ids_offset = n_experts * row; + + float logits_r[experts_per_thread]; + + const float INFINITY = 1.0 / 0.0; + + [[unroll]] + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + gl_LocalInvocationID.x; + logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY; + } + + float max_val = logits_r[0]; + + [[unroll]] + for (int i = 1; i < experts_per_thread; i++) { + const float val = logits_r[i]; + max_val = max(val, max_val); + } + + max_val = subgroupMax(max_val); + + float wt[experts_per_thread]; + float tmp = 0.f; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const float val = logits_r[i]; + wt[i] = exp(val - max_val); + tmp += wt[i]; + } + + tmp = subgroupAdd(tmp); + + const float inv_sum = 1.0f / tmp; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = wt[i] * inv_sum; + } + + // at this point, each thread holds a portion of softmax, + // we do the argmax reduce over n_expert_used, each time marking + // the expert weight as -inf to exclude from the next iteration + + float wt_sum = 0.f; + + float output_weights[experts_per_thread]; + + for (int k = 0; k < n_expert_used; k++) { + float max_val = wt[0]; + uint max_expert = gl_LocalInvocationID.x; + + [[unroll]] + for (int i = 1; i < experts_per_thread; i++) { + const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } + } + + [[unroll]] + for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = subgroupShuffleXor(max_val, mask); + const uint expert = subgroupShuffleXor(max_expert, mask); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + output_weights[k / WARP_SIZE] = max_val; + } + + if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; + + ids[ids_offset + k] = max_expert; + if (with_norm) { + wt_sum += max_val; + } + } + } + + if (with_norm) { + wt_sum = subgroupAdd(wt_sum); + const float inv_sum = 1.0f / wt_sum; + + [[unroll]] + for (uint i = 0; i < experts_per_thread; ++i) { + output_weights[i] *= inv_sum; + } + } + + [[unroll]] + for (uint i = 0; i < experts_per_thread; ++i) { + uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; + if (idx < n_expert_used) { + weights[weights_offset + idx] = output_weights[i]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 1d04a812..49bf6c76 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -920,6 +920,8 @@ void process_shaders() { string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_moe_f32", "topk_moe.comp", {}); + for (auto &c : compiles) { c.wait(); } From 72d98011dbd1e668a2711da93343dd3cc9319389 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 18 Oct 2025 14:47:32 +0200 Subject: [PATCH 095/104] HIP: fix GPU_TARGETS (llama/16642) --- ggml/src/ggml-hip/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 934aefdc..6b499320 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -28,8 +28,10 @@ if (CXX_IS_HIPCC) " Prefer setting the HIP compiler directly. See README for details.") endif() else() - # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES. - if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES. + if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS}) + elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) endif() cmake_minimum_required(VERSION 3.21) From 82bdf31267143f62d71cd10e2684c3c7bec77c63 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Mon, 20 Oct 2025 05:06:39 +0800 Subject: [PATCH 096/104] ci : fix binaries release failure for s390x (binaries may not work yet) (llama/16664) * devops: initial patch Signed-off-by: Aaron Teo * devops: forgot the z15 suffix Signed-off-by: Aaron Teo * devops: attempt at impl GGML_CPU_ALL_VARIANTS for s390x Signed-off-by: Aaron Teo * devops: rm baseline version Signed-off-by: Aaron Teo --------- Signed-off-by: Aaron Teo --- ggml/src/CMakeLists.txt | 12 +++++++ ggml/src/ggml-cpu/CMakeLists.txt | 56 ++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 892c2331..3356ef55 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -307,6 +307,10 @@ function(ggml_add_cpu_backend_variant tag_name) foreach (feat ${ARGN}) set(GGML_INTERNAL_${feat} ON) endforeach() + elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") + foreach (feat ${ARGN}) + set(GGML_INTERNAL_${feat} ON) + endforeach() endif() ggml_add_cpu_backend_variant_impl(${tag_name}) @@ -371,6 +375,14 @@ if (GGML_CPU_ALL_VARIANTS) else() message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}") endif() + elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE) + # ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE) + # ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE) + else() + message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}") + endif() else() message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") endif() diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 42041b71..34323afa 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -466,29 +466,45 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") - list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c) - file(READ "/proc/cpuinfo" CPUINFO_CONTENTS) - string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS}) + list(APPEND GGML_CPU_SOURCES + ggml-cpu/arch/s390/quants.c) - # TODO: Separation to determine activation of VX/VXE/VXE2 - if (${S390X_M} MATCHES "8561|8562") - message(STATUS "z15 target") - list(APPEND ARCH_FLAGS -march=z15) - elseif (${S390X_M} MATCHES "3931") - message(STATUS "z16 target") - list(APPEND ARCH_FLAGS -march=z16) - elseif (${S390X_M} MATCHES "9175|9176") - # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version. - # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15. - message(STATUS "z17 target") - list(APPEND ARCH_FLAGS -march=arch15) - else() - message(STATUS "Unknown target") - message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.") - list(APPEND ARCH_FLAGS -march=native -mtune=native) + # for native compilation + if (GGML_NATIVE) + # check machine level to determine target + file(READ "/proc/cpuinfo" CPUINFO_CONTENTS) + string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS}) + + # TODO: Separation to determine activation of VX/VXE/VXE2 + if (${S390X_M} MATCHES "8561|8562") + message(STATUS "z15 target") + list(APPEND ARCH_FLAGS -march=z15) + elseif (${S390X_M} MATCHES "3931") + message(STATUS "z16 target") + list(APPEND ARCH_FLAGS -march=z16) + elseif (${S390X_M} MATCHES "9175|9176") + # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version. + # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15. + message(STATUS "z17 target") + list(APPEND ARCH_FLAGS -march=arch15) + else() + message(STATUS "Unknown target") + message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.") + list(APPEND ARCH_FLAGS -march=native -mtune=native) + endif() + # for cross-compilation + elseif(GGML_CPU_ALL_VARIANTS) + # range through IBM z15 to z17 + # NOTE: update when a new hardware level is released + foreach (ZHW RANGE 15 17) + if(DEFINED GGML_INTERNAL_Z${ZHW}) + message(STATUS "z${ZHW} cross-compile target") + list(APPEND ARCH_FLAGS -march=z${ZHW}) + endif() + endforeach() endif() - if (GGML_VXE) + if (GGML_VXE OR GGML_INTERNAL_VXE) message(STATUS "VX/VXE/VXE2 enabled") list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_DEFINITIONS GGML_VXE) From bb76672081889215c673c6f0b8f8e4ae53735864 Mon Sep 17 00:00:00 2001 From: safranowith Date: Mon, 20 Oct 2025 11:08:32 +0300 Subject: [PATCH 097/104] SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators (llama/16613) * SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators Clean up unrelated changes from previous commit * Chore: remove empty lines and fix indentation * Clean up: remove leftover blank lines and fix spacing * chore: fix trailing whitespace and ensure final newline * Cleanup: remove redundant declarations already defined in header * Sync docs/ops.md with updated backend operation support * docs: update ops.md after rebase * docs: update ops.md - Vulkan supports SSM_CONV and SSM_SCAN --- ggml/src/ggml-sycl/element_wise.cpp | 120 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/element_wise.hpp | 4 + ggml/src/ggml-sycl/ggml-sycl.cpp | 16 ++++ 3 files changed, 140 insertions(+) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 58f5125c..810995d0 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { return x < static_cast(min_val) ? static_cast(min_val) : (x > static_cast(max_val) ? static_cast(max_val) : x); } +template +static __dpct_inline__ T op_floor(T x) { + return sycl::floor(x); +} + +template +static __dpct_inline__ T op_ceil(T x) { + return sycl::ceil(x); +} + +template +static __dpct_inline__ T op_round(T x) { + return sycl::round(x); +} + +template +static __dpct_inline__ T op_trunc(T x) { + return sycl::trunc(x); +} + template static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { SYCL_GLOBAL_ID_LOOP(k, item_ct1) { @@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl: } } +template +static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_floor(x[i]); + } +} + +template +static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_ceil(x[i]); + } +} + +template +static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_round(x[i]); + } +} + +template +static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_trunc(x[i]); + } +} + template static void upscale(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, @@ -897,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens }, min_val, max_val); } +static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); @@ -1122,3 +1222,23 @@ void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst); } + +void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_floor(ctx, dst); +} + +void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_ceil(ctx, dst); +} + +void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_round(ctx, dst); +} + +void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_trunc(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index ed96c55f..fcf93295 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -80,6 +80,10 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a7e077ec..1a007ffe 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3698,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_sycl_elu(ctx, dst); break; + case GGML_UNARY_OP_FLOOR: + ggml_sycl_floor(ctx, dst); + break; + case GGML_UNARY_OP_CEIL: + ggml_sycl_ceil(ctx, dst); + break; + case GGML_UNARY_OP_ROUND: + ggml_sycl_round(ctx, dst); + break; + case GGML_UNARY_OP_TRUNC: + ggml_sycl_trunc(ctx, dst); + break; default: return false; } @@ -4262,6 +4274,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: #if defined (GGML_SYCL_F16) return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); #else From 70b4d22f01ce91ef6c2bf6231ac4c6c6af8ff670 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Mon, 20 Oct 2025 05:53:50 -0700 Subject: [PATCH 098/104] ggml-alloc : fix leak when reusing a tensor with a larger size (llama/16679) --- ggml/src/ggml-alloc.c | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 929bc448..c830c096 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -598,6 +598,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; } +// free the extra space at the end if the new tensor is smaller +static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) { + struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); + struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); + + size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent); + size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); + + GGML_ASSERT(parent_size >= node_size); + + if (parent_size > node_size) { + struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id]; + struct buffer_address p_addr = p_hn->addr; + p_addr.offset += node_size; + size_t extra_size = parent_size - node_size; + AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name); + ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent); + } +} + static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); @@ -643,6 +663,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor hn->addr = p_hn->addr; p_hn->allocated = false; // avoid freeing the parent view_src_hn->allocated = false; + ggml_gallocr_free_extra_space(galloc, node, view_src); return; } } else { @@ -650,6 +671,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor hn->buffer_id = p_hn->buffer_id; hn->addr = p_hn->addr; p_hn->allocated = false; // avoid freeing the parent + ggml_gallocr_free_extra_space(galloc, node, parent); return; } } From 55cf00c20a1f10492a54ccb6a98043cd9444fbc5 Mon Sep 17 00:00:00 2001 From: YehuditE Date: Tue, 21 Oct 2025 01:21:12 +0300 Subject: [PATCH 099/104] sycl : add PAD_REFLECT_D1 operator support (llama/16145) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sycl: add PAD_REFLECT_D1 operator support * docs(ops): regenerate docs/ops.md * remove trailing whitespaces * style: fix editorconfig issues — trim trailing spaces and normalize EOLs * fix: move PAD_REFLECT_1D case outside of fall-through block --- ggml/src/ggml-sycl/backend.hpp | 2 + ggml/src/ggml-sycl/ggml-sycl.cpp | 5 ++ ggml/src/ggml-sycl/pad_reflect_1d.cpp | 72 +++++++++++++++++++++++++++ ggml/src/ggml-sycl/pad_reflect_1d.hpp | 8 +++ 4 files changed, 87 insertions(+) create mode 100644 ggml/src/ggml-sycl/pad_reflect_1d.cpp create mode 100644 ggml/src/ggml-sycl/pad_reflect_1d.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 6ff3215d..b1575b81 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -37,5 +37,7 @@ #include "softmax.hpp" #include "tsembd.hpp" #include "wkv.hpp" +#include "pad_reflect_1d.hpp" + #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 1a007ffe..33f90350 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3744,6 +3744,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_CONCAT: ggml_sycl_op_concat(ctx, dst); break; + case GGML_OP_PAD_REFLECT_1D: + ggml_sycl_op_pad_reflect_1d(ctx,dst); + break; case GGML_OP_UPSCALE: ggml_sycl_upscale(ctx, dst); break; @@ -4455,6 +4458,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_DIV: case GGML_OP_REPEAT: return true; + case GGML_OP_PAD_REFLECT_1D: + return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.cpp b/ggml/src/ggml-sycl/pad_reflect_1d.cpp new file mode 100644 index 00000000..e56655a9 --- /dev/null +++ b/ggml/src/ggml-sycl/pad_reflect_1d.cpp @@ -0,0 +1,72 @@ +#include "pad_reflect_1d.hpp" + +void pad_reflect_1d_f32(const float* src,float* dst, + const int64_t ne0, const int64_t ne02, const int p0, const int p1, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const sycl::nd_item<3> &item_ct1){ + + const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0); + const int i1 = item_ct1.get_group(1); + const int g2 = item_ct1.get_group(2); + const int i2 = g2 % ne02; + const int i3 = g2 / ne02; + + if (i0 >= p0 + ne0 + p1) return; + + int t = i0 - p0; + int period = 2 * ne0 -2; + int m = t % period; + m += (m < 0) * period; + int center = ne0 -1; + int srci0 = center - abs(center - m); + + int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0; + int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00; + dst[offest_dst] = src[offest_src]; + +} + +void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){ + + const ggml_tensor * src0 = dst->src[0]; + queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + const int64_t ne0 = src0->ne[0]; + + const int64_t ne00 = dst->ne[0]; + const int64_t ne01 = dst->ne[1]; + const int64_t ne02 = dst->ne[2]; + const int64_t ne03 = dst->ne[3]; + + const int64_t nb00 = dst->nb[0]; + const int64_t nb01 = dst->nb[1]; + const int64_t nb02 = dst->nb[2]; + const int64_t nb03 = dst->nb[3]; + const int64_t nb0 = src0->nb[0]; + const int64_t nb1 = src0->nb[1]; + const int64_t nb2 = src0->nb[2]; + const int64_t nb3 = src0->nb[3]; + + int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; + sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03); + sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1); + + stream->parallel_for( + sycl::nd_range<3>(global, + local), + [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32( + (const float *) src0->data, (float *) dst->data, + ne0, ne02, p0, p1, + nb0, nb1, nb2, nb3, + nb00, nb01, nb02, nb03 + , item_ct1); + }); +} diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.hpp b/ggml/src/ggml-sycl/pad_reflect_1d.hpp new file mode 100644 index 00000000..a24509de --- /dev/null +++ b/ggml/src/ggml-sycl/pad_reflect_1d.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_PAD_REFLECT_1D_HPP +#define GGML_SYCL_PAD_REFLECT_1D_HPP + +#include "common.hpp" + +void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + +#endif // GGML_SYCL_PAD_REFLECT_1D_HPP From 7f16c7106851f299f5f2e853795ec689fae9a752 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 20 Oct 2025 22:16:08 -0500 Subject: [PATCH 100/104] vulkan: Handle FA with all -inf mask values (llama/16447) --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 6 +++++- .../vulkan-shaders/flash_attn_split_k_reduce.comp | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 62acbf10..2255f9c1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -345,7 +345,7 @@ void main() { float Lfrcp[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Lfrcp[r] = 1.0 / Lf[r]; + Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 2066a05b..8699fa6c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -380,7 +380,7 @@ void main() { float Lfrcp[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Lfrcp[r] = 1.0 / Lf[r]; + Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 910da1ab..fcfc60a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -121,7 +121,11 @@ void main() { const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); L = coopmat(0); +#if defined(ACC_TYPE_MAX) + M = coopmat(-ACC_TYPE_MAX / ACC_TYPE(2)); +#else M = coopmat(NEG_FLT_MAX_OVER_2); +#endif coopmat slopeMat = coopmat(1.0); @@ -294,7 +298,7 @@ void main() { [[unroll]] for (int k = 0; k < Ldiag.length(); ++k) { - Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]); } O = Ldiag*O; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 06e83822..4eaddd31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -91,7 +91,7 @@ void main() { L = L*ms + vs; } - L = 1.0 / L; + L = (L == 0.0) ? 0.0 : 1.0 / L; // D dimension is split across workgroups in the y dimension uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; From 5c4c477d00bc6f2bfac27c499e09993aaa54b568 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 20 Oct 2025 22:26:17 -0700 Subject: [PATCH 101/104] opencl: fix warnings and clean up profiling (llama/16688) * opencl: remove unused headers, fix warnings * opencl: clean up profiling, only keep kernel time --- ggml/src/ggml-opencl/ggml-opencl.cpp | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d9876e69..db33a4ab 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -15,13 +15,12 @@ #include +#include #include #include #include -#include #include -#include #include #include #include @@ -533,25 +532,17 @@ struct ggml_backend_opencl_context { } // Dump a csv - float total_kernel_time = 0; - fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n"); + fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n"); for (const ProfilingInfo & info : profiling_info) { - total_kernel_time += info.cmd_duration_ns/1.e6f; - fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", + fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", info.op_name.c_str(), info.kernel_name.c_str(), - info.cmd_queued_duration_ns/1.e6f, - info.cmd_submit_duration_ns/1.e6f, info.cmd_duration_ns/1.e6f, - info.cmd_complete_duration_ns/1.e6f, - info.cmd_total_duration_ns/1.e6f, info.global_size[0], info.global_size[1], info.global_size[2], info.local_size[0], info.local_size[1], info.local_size[2], info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); } fclose(fperf); - GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time); - // Dump a simple chrome trace FILE* ftrace = fopen("cl_trace.json", "w"); if (!ftrace) { @@ -561,14 +552,14 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "[\n"); for (const ProfilingInfo & info : profiling_info) { - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_submit/1000); - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_start/1000); - fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n", + fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } fclose(ftrace); @@ -7652,6 +7643,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong nb21 = src2->nb[1]; const cl_ulong nb20 = src2->nb[0]; + UNUSED(nb20); + const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; From 9a8cfb040ccabbbc94ae0ba9d3edb053e671a617 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 21 Oct 2025 16:43:14 +0800 Subject: [PATCH 102/104] ggml: add ggml_can_fuse_subgraph (llama/16662) * ggml: add ggml_can_fuse_subgraph * ggml-cuda: use ggml_can_fuse_subgraph for topk-moe * format * 1. remove inputs from signature as they are transient nodes 2. add check for views: view_src should be part of the subgraph * - combine check into one loop - check all view_src parents - other minor review comments * remove redudant if test * - rename and other minor review comments * add assert about count < 32 --- ggml/src/ggml-cuda/ggml-cuda.cu | 23 ++--------- ggml/src/ggml-impl.h | 37 +++++++++++++++++ ggml/src/ggml.c | 72 +++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75fd6db1..015b37be 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2821,15 +2821,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops = ggml_cuda_topk_moe_ops(false); std::initializer_list topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true); - if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) { - - if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) { - return false; - } - - for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false; - } + if (ops.size() == topk_moe_ops_with_norm.size() && + ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx+8]; @@ -2838,16 +2831,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) { - - if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) { - return false; - } - - for (size_t i = 0; i < topk_moe_ops.size(); i++) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false; - } - + if (ops.size() == topk_moe_ops.size() && + ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx+4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 18f095b8..e9201cdc 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); } +GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, + const int * node_idxs, + int count, + const enum ggml_op * ops, + const int * outputs, + int num_outputs); + +// Returns true if the subgraph formed by {node_idxs} can be fused +// checks whethers all nodes which are not part of outputs can be elided +// by checking if their num_uses are confined to the subgraph +static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, + int node_idx, + int count, + const enum ggml_op * ops, + const int * outputs, + int num_outputs) { + GGML_ASSERT(count < 32); + if (node_idx + count > cgraph->n_nodes) { + return false; + } + + int idxs[32]; + + for (int i = 0; i < count; ++i) { + idxs[i] = node_idx + i; + } + + return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs); +} + #ifdef __cplusplus } #endif @@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std:: return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size()); } +inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, + int start_idx, + std::initializer_list ops, + std::initializer_list outputs = {}) { + return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size()); +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 86f1c31a..9be35c1b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_LOG_INFO("========================================\n"); } +static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph, + const int * idxs, + int count, + const struct ggml_tensor * tensor) { + GGML_ASSERT(cgraph && idxs); + for (int i = 0; i < count; ++i) { + const int node_idx = idxs[i]; + + if (node_idx >= cgraph->n_nodes) { + return -1; + } + if (cgraph->nodes[node_idx] == tensor) { + return i; + } + } + return -1; +} + +bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, + const int * node_idxs, + int count, + const enum ggml_op * ops, + const int * outputs, + int num_outputs) { + GGML_ASSERT(outputs && num_outputs > 0); + + for (int i = 0; i < count; ++i) { + if (node_idxs[i] >= cgraph->n_nodes) { + return false; + } + + const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]]; + + if (node->op != ops[i]) { + return false; + } + + if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) { + continue; + } + + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + return false; + } + + int subgraph_uses = 0; + for (int j = i + 1; j < count; ++j) { + const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]]; + for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) { + if (other_node->src[src_idx] == node) { + subgraph_uses++; + } + } + } + + if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) { + return false; + } + + // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph + struct ggml_tensor * view_src = node->view_src; + while (view_src) { + if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) { + return false; + } + view_src = view_src->view_src; + } + } + + return true; +} + // check if node is part of the graph static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { if (cgraph == NULL) { From 35ea5ced60512093e08350ab46698b0c5d71c934 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 Oct 2025 08:28:23 +0300 Subject: [PATCH 103/104] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 524e2b1c..aaceb7c5 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c538174d261d8172480f87efcfec8e69aac13ebb +999574b730626d57f7ad24a06074ac169e851dfa From 322c2adb753a9506f0becee134a7f75e2a6b5687 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 Oct 2025 08:32:16 +0300 Subject: [PATCH 104/104] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 40 +++ examples/talk-llama/llama-arch.h | 4 + examples/talk-llama/llama-batch.h | 2 +- examples/talk-llama/llama-chat.cpp | 37 ++- examples/talk-llama/llama-chat.h | 2 + examples/talk-llama/llama-context.cpp | 3 +- examples/talk-llama/llama-graph.cpp | 30 +++ examples/talk-llama/llama-hparams.h | 2 + examples/talk-llama/llama-model.cpp | 341 +++++++++++++++++++++++--- examples/talk-llama/llama-model.h | 3 + examples/talk-llama/llama-quant.cpp | 8 +- examples/talk-llama/llama-vocab.cpp | 1 + examples/talk-llama/llama.cpp | 3 + 13 files changed, 431 insertions(+), 45 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 869e4dcc..8ca769c5 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -5,6 +5,7 @@ #include static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize { LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_DECI, "deci" }, @@ -84,6 +85,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_BAILINGMOE2, "bailingmoe2" }, { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, @@ -134,6 +136,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" }, + { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, @@ -275,6 +279,10 @@ static const std::map LLM_KV_NAMES = { }; static const std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_CLIP, + {}, + }, { LLM_ARCH_LLAMA, { @@ -1941,6 +1949,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_BAILINGMOE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + }, + }, { LLM_ARCH_DOTS1, { diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index c3ae7165..dea725c1 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -9,6 +9,7 @@ // enum llm_arch { + LLM_ARCH_CLIP, LLM_ARCH_LLAMA, LLM_ARCH_LLAMA4, LLM_ARCH_DECI, @@ -88,6 +89,7 @@ enum llm_arch { LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, + LLM_ARCH_BAILINGMOE2, LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, @@ -138,6 +140,8 @@ enum llm_kv { LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_GROUP_COUNT, + LLM_KV_EXPERT_GROUP_USED_COUNT, LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index d563adc6..0dc8cebd 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -123,7 +123,7 @@ private: uint32_t n_seq_max; uint32_t n_outputs; - std::array seq_id_0 = { 0 }; // default sequence id + std::array seq_id_0 = {{ 0 }}; // default sequence id std::vector pos; std::vector n_seq_id; diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 956c4e08..0285006d 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -63,6 +63,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, { "bailing", LLM_CHAT_TEMPLATE_BAILING }, + { "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK }, + { "bailing2", LLM_CHAT_TEMPLATE_BAILING2 }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, @@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_YANDEX; } else if (tmpl_contains("ASSISTANT") && tmpl_contains("'HUMAN'")) { return LLM_CHAT_TEMPLATE_BAILING; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("\"HUMAN\"") && tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_BAILING_THINK; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("HUMAN") && tmpl_contains("<|role_end|>")) { + return LLM_CHAT_TEMPLATE_BAILING2; } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { return LLM_CHAT_TEMPLATE_LLAMA4; } else if (tmpl_contains("<|endofuserprompt|>")) { @@ -644,8 +650,8 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << " Ассистент:[SEP]"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) { - // Bailing (Ling) template + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) { + // Bailing (Ling/Ring) template for (auto message : chat) { std::string role(message->role); @@ -658,6 +664,33 @@ int32_t llm_chat_apply_template( ss << "" << role << "" << message->content; } + if (add_ass) { + ss << "ASSISTANT"; + + if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) { + ss << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) { + // Bailing2 (Ling 2.0) template + bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; + + if (!has_system) { + ss << "SYSTEMdetailed thinking off<|role_end|>"; + } + + for (auto message : chat) { + std::string role(message->role); + + if (role == "user") { + role = "HUMAN"; + } else { + std::transform(role.begin(), role.end(), role.begin(), ::toupper); + } + + ss << "" << role << "" << message->content << "<|role_end|>"; + } + if (add_ass) { ss << "ASSISTANT"; } diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 5a87d9ab..da1b7c47 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -42,6 +42,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, LLM_CHAT_TEMPLATE_BAILING, + LLM_CHAT_TEMPLATE_BAILING_THINK, + LLM_CHAT_TEMPLATE_BAILING2, LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index e7526e7d..bd348bca 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2346,7 +2346,8 @@ llama_context * llama_init_from_model( return nullptr; } - if (params.pooling_type != model->hparams.pooling_type) { + if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && + params.pooling_type != model->hparams.pooling_type) { //user-specified pooling-type is different from the model default LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, model->hparams.pooling_type, params.pooling_type); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index f29a1e98..41fa6894 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -950,6 +950,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(selection_probs, "ffn_moe_probs_biased", il); } + // select top n_group_used expert groups + // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457 + if (hparams.n_expert_groups > 1 && n_tokens > 0) { + const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; + + // organize experts into n_expert_groups + ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens] + + ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens] + group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens] + + // get top n_group_used expert groups + group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens] + group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens] + + ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens] + cb(expert_groups, "ffn_moe_group_topk", il); + + // mask out the other groups + selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens] + selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] + selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens] + cb(selection_probs, "ffn_moe_probs_masked", il); + } + // select experts ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); @@ -981,6 +1006,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] cb(weights_sum, "ffn_moe_weights_sum", il); + if (arch == LLM_ARCH_BAILINGMOE2) { + weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20); + cb(weights_sum, "ffn_moe_weights_sum_biased", il); + } + weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] cb(weights, "ffn_moe_weights_norm", il); diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 4e7f73ec..6fcf91b7 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -72,6 +72,8 @@ struct llama_hparams { uint32_t n_ff_chexp = 0; uint32_t n_expert_shared = 0; uint32_t n_norm_groups = 0; + uint32_t n_expert_groups = 0; + uint32_t n_group_used = 0; uint32_t n_group_experts = 0; float expert_group_scale = 0.05f; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 0cdad9ba..e4609963 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -114,9 +114,12 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; @@ -421,11 +424,8 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored - std::vector ctxs; - - // the model memory buffers for the tensor data - std::vector bufs; + // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + std::vector> ctxs_bufs; buft_list_t cpu_buft_list; std::map gpu_buft_list; @@ -478,15 +478,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_GENERAL_NAME, name, false); // everything past this point is not vocab-related - if (hparams.vocab_only) { + // for CLIP models, we only need to load tensors, no hparams + if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) { return; } - ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); - ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); + ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); @@ -502,8 +505,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); if (hparams.n_expert > 0) { GGML_ASSERT(hparams.n_expert_used > 0); + GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert); + if (hparams.n_expert_groups > 1) { + GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0); + GGML_ASSERT(hparams.n_group_used > 0); + GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups); + } } else { GGML_ASSERT(hparams.n_expert_used == 0); + GGML_ASSERT(hparams.n_expert_groups == 0); } std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); @@ -1845,8 +1855,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - // TODO: Add llm type label (not sure this is useful) + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_7B_A1B; break; + case 2048: case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -1887,6 +1899,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_BAILINGMOE2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 20: type = LLM_TYPE_16B_A1B; break; + case 21: type = LLM_TYPE_16B_A1B; break; + case 32: type = LLM_TYPE_100B_A6B; break; + case 33: type = LLM_TYPE_100B_A6B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_DOTS1: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2181,7 +2216,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { max_n_tensors += n_layer*2; // duplicated rope freq tensors const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; - std::map ctx_map; + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs); + } + }; + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -2196,12 +2238,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error(format("failed to create ggml context")); } - ctx_map[buft] = ctx; - pimpl->ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; @@ -5491,6 +5532,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); } } break; + case LLM_ARCH_BAILINGMOE2: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); + } + } + } break; case LLM_ARCH_DOTS1: { const int64_t n_ff_exp = hparams.n_ff_exp; @@ -6036,16 +6141,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { pimpl->mappings.reserve(ml.mappings.size()); // create the backend buffers - std::vector> ctx_bufs; - ctx_bufs.reserve(ctx_map.size()); + std::vector> ctx_buf_maps; + ctx_buf_maps.reserve(ctx_map.size()); // Ensure we have enough capacity for the maximum backend buffer we will potentially create const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); - pimpl->bufs.reserve(n_max_backend_buffer); + pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - for (auto & it : ctx_map) { - ggml_backend_buffer_type_t buft = it.first; - ggml_context * ctx = it.second; + for (auto & [buft, ctx_ptr] : ctx_map) { + ggml_context * ctx = ctx_ptr.get(); // skip contexts without tensors if (ggml_get_first_tensor(ctx) == nullptr) { @@ -6069,6 +6173,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); + ggml_backend_buffer_t buf = nullptr; if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) { // only the mmap region containing the tensors in the model is mapped to the backend buffer @@ -6081,20 +6186,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) { continue; } const size_t max_size = ggml_get_max_tensor_size(ctx); - ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - pimpl->bufs.emplace_back(buf); buf_map.emplace(idx, buf); } } else { - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - pimpl->bufs.emplace_back(buf); if (use_mlock && ggml_backend_buffer_is_host(buf)) { pimpl->mlock_bufs.emplace_back(new llama_mlock); auto & mlock_buf = pimpl->mlock_bufs.back(); @@ -6105,10 +6208,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { buf_map.emplace(idx, buf); } } - - if (pimpl->bufs.empty()) { - throw std::runtime_error("failed to allocate buffer"); - } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf); for (auto & buf : buf_map) { // indicate that this buffer contains weights @@ -6116,7 +6216,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); } - ctx_bufs.emplace_back(ctx, buf_map); + ctx_buf_maps.emplace_back(ctx, buf_map); } if (llama_supports_gpu_offload()) { @@ -6134,22 +6234,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // print memory requirements per buffer type - for (auto & buf : pimpl->bufs) { + for (auto & [_, buf] : pimpl->ctxs_bufs) { LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); } // populate tensors_by_name - for (auto & ctx : pimpl->ctxs) { + for (auto & [ctx, _] : pimpl->ctxs_bufs) { for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { tensors_by_name.emplace_back(ggml_get_name(cur), cur); } } // load tensor data - for (auto & it : ctx_bufs) { - ggml_context * ctx = it.first; - auto & bufs = it.second; - if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + for (auto & [ctx, buf_map] : ctx_buf_maps) { + if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { return false; } } @@ -6189,8 +6287,8 @@ size_t llama_model::n_devices() const { std::map llama_model::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : pimpl->ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -6353,6 +6451,19 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); @@ -17042,6 +17153,150 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; +struct llm_build_bailingmoe2 : public llm_graph_context { + llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA); + cb(sa_out, "sa_out", il); + + // MoE branch + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if (static_cast(il) < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_dots1 : public llm_graph_context { llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -19838,6 +20093,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_BAILINGMOE2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_SEED_OSS: { llm = std::make_unique(*this, params); @@ -20013,6 +20272,7 @@ int32_t llama_n_head(const llama_model * model) { llama_rope_type llama_model_rope_type(const llama_model * model) { switch (model->arch) { // these models do not use RoPE + case LLM_ARCH_CLIP: case LLM_ARCH_GPT2: case LLM_ARCH_GPTJ: case LLM_ARCH_MPT: @@ -20103,6 +20363,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: case LLM_ARCH_MINICPM3: + case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: case LLM_ARCH_OPENAI_MOE: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 7f48662f..248f8541 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -107,9 +107,12 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 97228b2a..6dd40412 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } + bool is_clip_model = false; for (const auto * it : tensors) { const struct ggml_tensor * tensor = it->tensor; @@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { qs.has_output = true; } + + is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix } qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks for models that have attention layers - if (qs.n_attention_wv != 0) + if (qs.n_attention_wv != 0 && !is_clip_model) { const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); // attention layers have a non-zero number of kv heads @@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; + // do not quantize specific multimodal tensors + quantize &= name.find(".position_embd.") == std::string::npos; + ggml_type new_type; void * new_data; size_t new_size; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 7fffd171..639fecbd 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -1968,6 +1968,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { clean_spaces = false; } else if ( tokenizer_pre == "bailingmoe" || + tokenizer_pre == "bailingmoe2" || tokenizer_pre == "llada-moe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 38700f97..ab2e9868 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector } catch(const std::exception & e) { throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); } + if (model.arch == LLM_ARCH_CLIP) { + throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); + } try { model.load_vocab(ml); } catch(const std::exception & e) {